In [1]:
import collections
import os
import multiprocessing
import sys

import shap
import pandas as pd
import numpy as np
import lightgbm as lgb
import sklearn.metrics
import ml_metrics
import tqdm
from gensim.models import Word2Vec

In [2]:
# Reproducibility parameters
SEED = 281989

# Explainability parameters
EXPLAIN_FROM_SAMPLES = 1000

# LightGBM parameters
LIGHTGBM_LEARNING_RATE = 0.2
LIGHTGBM_NUM_BOOST_ROUND = 50

# Word2Vec parameters
WORD2VEC_EMBEDDING_SIZE = 32
WORD2VEC_EPOCHS = 10

# Evaluation parameters
CROSS_VALIDATION_K = 5
MAP_K = 10
EVAL_USERS_LIMIT = 100

In [3]:
# Data description: https://www.kaggle.com/c/kkbox-music-recommendation-challenge/data
data_dir = 'data'

In [4]:
train_columns = {
  'msno': 'category',
  'song_id': 'category',
  'source_system_tab': 'category',
  'source_screen_name': 'category',
  'source_type': 'category',
  'target': np.float,
}
train_df_raw = pd.read_csv(os.path.join(data_dir, 'train.csv'), dtype=train_columns)
print(train_df_raw.dtypes)
train_df_raw.head()

msno                  category
song_id               category
source_system_tab     category
source_screen_name    category
source_type           category
target                 float64
dtype: object


Unnamed: 0,msno,song_id,source_system_tab,source_screen_name,source_type,target
0,FGtllVqz18RPiwJj/edr2gV78zirAiY/9SmYvia+kCg=,BBzumQNXUHKdEBOB7mAJuzok+IJA1c2Ryg/yzTF6tik=,explore,Explore,online-playlist,1.0
1,Xumu+NIjS6QYVxDS4/t3SawvJ7viT9hPKXmf0RtLNx8=,bhp/MpSNoqoxOIB+/l8WPqu6jldth4DIpCm3ayXnJqM=,my library,Local playlist more,local-playlist,1.0
2,Xumu+NIjS6QYVxDS4/t3SawvJ7viT9hPKXmf0RtLNx8=,JNWfrrC7zNN7BdMpsISKa4Mw+xVJYNnxXh3/Epw7QgY=,my library,Local playlist more,local-playlist,1.0
3,Xumu+NIjS6QYVxDS4/t3SawvJ7viT9hPKXmf0RtLNx8=,2A87tzfnJTSWqD7gIZHisolhe4DMdzkbd6LzO1KHjNs=,my library,Local playlist more,local-playlist,1.0
4,FGtllVqz18RPiwJj/edr2gV78zirAiY/9SmYvia+kCg=,3qm6XTZ6MOCU11x8FIVbAGH5l5uMkT3/ZalWG1oo2Gc=,explore,Explore,online-playlist,1.0


In [5]:
test_columns = {
  'id': np.int,
  'msno': 'category',
  'song_id': 'category',
  'source_system_tab': 'category',
  'source_screen_name': 'category',
  'source_type': 'category',
}
test_df_raw = pd.read_csv(os.path.join(data_dir, 'test.csv'), dtype=test_columns)
print(test_df_raw.dtypes)
test_df_raw.head()

id                       int64
msno                  category
song_id               category
source_system_tab     category
source_screen_name    category
source_type           category
dtype: object


Unnamed: 0,id,msno,song_id,source_system_tab,source_screen_name,source_type
0,0,V8ruy7SGk7tDm3zA51DPpn6qutt+vmKMBKa21dp54uM=,WmHKgKMlp1lQMecNdNvDMkvIycZYHnFwDT72I5sIssc=,my library,Local playlist more,local-library
1,1,V8ruy7SGk7tDm3zA51DPpn6qutt+vmKMBKa21dp54uM=,y/rsZ9DC7FwK5F2PK2D5mj+aOBUJAjuu3dZ14NgE0vM=,my library,Local playlist more,local-library
2,2,/uQAlrAkaczV+nWCd2sPF2ekvXPRipV7q0l+gbLuxjw=,8eZLFOdGVdXBSqoAv5nsLigeH2BvKXzTQYtUM53I0k4=,discover,,song-based-playlist
3,3,1a6oo/iXKatxQx4eS9zTVD+KlSVaAFbTIqVvwLC1Y0k=,ztCf8thYsS4YN3GcIL/bvoxLm/T5mYBVKOO4C9NiVfQ=,radio,Radio,radio
4,4,1a6oo/iXKatxQx4eS9zTVD+KlSVaAFbTIqVvwLC1Y0k=,MKVMpslKcQhMaFEgcEQhEfi5+RZhMYlU3eRDpySrH8Y=,radio,Radio,radio


In [6]:
songs_columns = {
  'song_id': 'category',
  'song_length': np.int,
  'genre_ids': 'category',
  'artist_name': 'category',
  'composer': 'category',
  'lyricist': 'category',
  'language': 'category',
}
songs_df_raw = pd.read_csv(os.path.join(data_dir, 'songs.csv'), dtype=songs_columns)
print(songs_df_raw.dtypes)
songs_df_raw.head()

song_id        category
song_length       int64
genre_ids      category
artist_name    category
composer       category
lyricist       category
language       category
dtype: object


Unnamed: 0,song_id,song_length,genre_ids,artist_name,composer,lyricist,language
0,CXoTN1eb7AI+DntdU1vbcwGRV4SCIDxZu+YD8JP8r4E=,247640,465,張信哲 (Jeff Chang),董貞,何啟弘,3.0
1,o0kFgae9QtnYgRkVPqLJwa05zIhRlUjfF7O1tDw0ZDU=,197328,444,BLACKPINK,TEDDY| FUTURE BOUNCE| Bekuh BOOM,TEDDY,31.0
2,DwVvVurfpuz+XPuFvucclVQEyPqcpUkHR0ne1RQzPs0=,231781,465,SUPER JUNIOR,,,31.0
3,dKMBWoZyScdxSkihKG+Vf47nc18N9q4m58+b4e7dSSE=,273554,465,S.H.E,湯小康,徐世珍,3.0
4,W3bqWd3T+VeHFzHAUfARgW9AvVRaF4N5Yzm4Mr6Eo/o=,140329,726,貴族精選,Traditional,Traditional,52.0


In [7]:
members_columns = {
  'msno': 'category',
  'city': 'category',
  'bd': np.int,
  'gender': 'category',
  'registered_via': 'category',
  'registration_init_time': str,
  'expiration_date': str,
}
members_df_raw = pd.read_csv(os.path.join(data_dir, 'members.csv'), dtype=members_columns)
print(members_df_raw.dtypes)
members_df_raw.head()

msno                      category
city                      category
bd                           int64
gender                    category
registered_via            category
registration_init_time      object
expiration_date             object
dtype: object


Unnamed: 0,msno,city,bd,gender,registered_via,registration_init_time,expiration_date
0,XQxgAYj3klVKjR3oxPPXYYFp4soD4TuBghkhMTD4oTw=,1,0,,7,20110820,20170920
1,UizsfmJb9mV54qE9hCYyU07Va97c0lCRLEQX3ae+ztM=,1,0,,7,20150628,20170622
2,D8nEhsIOBSoE6VthTaqDX8U6lqjJ7dLdr72mOyLya2A=,1,0,,4,20160411,20170712
3,mCuD+tZ1hERA/o5GPqk38e041J8ZsBaLcu7nGoIIvhI=,1,0,,9,20150906,20150907
4,q4HRBfVSssAFS9iRfxWrohxuk9kCYMKjHOEagUMV6rQ=,1,0,,4,20170126,20170613


In [8]:
song_extra_info_columns = {
  'song_id': 'category',
  'name': 'category',
  'isrc': 'category',
}
song_extra_info_df_raw = pd.read_csv(os.path.join(data_dir, 'song_extra_info.csv'), dtype=song_extra_info_columns)
print(song_extra_info_df_raw.dtypes)
song_extra_info_df_raw.head()

song_id    category
name       category
isrc       category
dtype: object


Unnamed: 0,song_id,name,isrc
0,LP7pLJoJFBvyuUwvu+oLzjT+bI+UeBPURCecJsX1jjs=,我們,TWUM71200043
1,ClazTFnk6r0Bnuie44bocdNMM3rdlrq0bCGAsGUWcHE=,Let Me Love You,QMZSY1600015
2,u2ja/bZE3zhCGxvbbOB3zOoUjx27u40cf5g09UXMoKQ=,原諒我,TWA530887303
3,92Fqsy0+p6+RHe2EoLKjHahORHR1Kq1TBJoClW9v+Ts=,Classic,USSM11301446
4,0QFmz/+rJy1Q56C1DuYqT9hKKqi5TUqx0sN0IwvoHrw=,愛投羅網,TWA471306001


In [9]:
song_columns_to_use = ['song_id', 'artist_name', 'genre_ids', 'song_length', 'language']
train_df = train_df_raw.merge(songs_df_raw[song_columns_to_use], on='song_id', how='left')
test_df = test_df_raw.merge(songs_df_raw[song_columns_to_use], on='song_id', how='left')

In [10]:
def parse_yearmonthday(date):
  parts = date[0:4], date[4:6], date[6:8]
  return list(map(int, parts))

In [11]:
members_df = members_df_raw.copy()

In [12]:
reg_init_label = 'registration_init_time'
members_df['registration_year'] = members_df[reg_init_label].apply(lambda date: parse_yearmonthday(date)[0])
members_df['registration_month'] = members_df[reg_init_label].apply(lambda date: parse_yearmonthday(date)[1])
members_df['registration_day'] = members_df[reg_init_label].apply(lambda date: parse_yearmonthday(date)[2])
members_df = members_df.drop([reg_init_label], axis=1)

In [13]:
exp_label = 'expiration_date'
members_df['expiration_year'] = members_df[exp_label].apply(lambda date: parse_yearmonthday(date)[0])
members_df['expiration_month'] = members_df[exp_label].apply(lambda date: parse_yearmonthday(date)[1])
members_df['expiration_day'] = members_df[exp_label].apply(lambda date: parse_yearmonthday(date)[2])
members_df = members_df.drop([exp_label], axis=1)

In [14]:
train_df = train_df.merge(members_df, on='msno', how='left')
test_df = test_df.merge(members_df, on='msno', how='left')

In [15]:
def isrc_to_year(isrc):
  if type(isrc) == str:
    if int(isrc[5:7]) > 17:
      return 1900 + int(isrc[5:7])
    else:
      return 2000 + int(isrc[5:7])
  else:
    return np.nan

song_extra_info_df = song_extra_info_df_raw.copy()
song_extra_info_df['song_year'] = song_extra_info_df['isrc'].apply(isrc_to_year)
song_extra_info_df.drop(['isrc', 'name'], axis=1, inplace=True)

In [16]:
train_df = train_df.merge(song_extra_info_df, on='song_id', how='left')
test_df = test_df.merge(song_extra_info_df, on='song_id', how='left')

In [17]:
for col in train_df.columns:
  if train_df[col].dtype == object:
    train_df[col] = train_df[col].astype('category')
    test_df[col] = test_df[col].astype('category')

In [18]:
def train_lightgbm(X_train, y_train):
  D_train = lgb.Dataset(X_train, y_train)
  watchlist = [D_train]

  params = {
    'learning_rate': LIGHTGBM_LEARNING_RATE,
    'objective': 'binary',
    'max_depth': 8,
    'num_leaves': 2**8,
    'verbosity': 0,
    'metric': 'auc',
  }

  print('Training the model')
  model = lgb.train(
    params,
    train_set=D_train,
    num_boost_round=LIGHTGBM_NUM_BOOST_ROUND,
    valid_sets=watchlist,
    verbose_eval=10)
  return model

In [19]:
def lightgbm_rank_by_user(model, user_id, songs_df):
  song_ids = songs_df.song_id.tolist()
  preds = model.predict(songs_df)
  indexes = list(range(len(song_ids)))
  indexes.sort(key=lambda index: preds[index], reverse=True)
  return [song_ids[index] for index in indexes]

In [20]:
ModelEvaluation = collections.namedtuple('ModelEvaluation', ['mapk'])

def evaluate_model_ranking_fn(ranking_fn, X_val, y_val) -> ModelEvaluation:
  print('Evaluating the model')
  sys.stdout.flush()
  val_user_ids = X_val.msno.unique().tolist()
  X_val_target_1 = X_val[y_val == 1.0]
  apks = []
  for user_id in tqdm.tqdm(val_user_ids[:EVAL_USERS_LIMIT], file=sys.stdout):
    songs_df = X_val[X_val.msno == user_id]
    predicted = ranking_fn(user_id, songs_df)
    actual = X_val_target_1[X_val_target_1.msno == user_id].song_id.tolist()
    apk = ml_metrics.apk(actual, predicted, MAP_K)
    apks.append(apk)
  sys.stdout.flush()
  evaluation = ModelEvaluation(mapk=np.mean(apks))
  print(evaluation)
  return evaluation

In [21]:
def k_fold_cross_validation(train_df, k, train_model_fn, preprocess_val_fn, build_ranking_fn):
  indexes = np.array(train_df.index)
  np.random.seed(SEED)
  np.random.shuffle(indexes)
  X_train_splits = np.array_split(indexes, k)
  
  evaluations = []
  for split_index, val_indexes in enumerate(X_train_splits):
    print('*' * 100)
    print(f'Taking split number {split_index} as the validation set...')
    train_indexes = np.setdiff1d(indexes, val_indexes)
    print('Train indexes:', len(train_indexes))
    print('Val indexes:', len(val_indexes))

    X_train = train_df.iloc[train_indexes].drop(['target'], axis=1)
    y_train = train_df.iloc[train_indexes]['target'].values
    X_val = train_df.iloc[val_indexes].drop(['target'], axis=1)
    y_val = train_df.iloc[val_indexes]['target'].values
    
    model = train_model_fn(X_train, y_train)
    if preprocess_val_fn:
      X_val = preprocess_val_fn(model, X_val)
    evaluation = evaluate_model_ranking_fn(build_ranking_fn(model), X_val, y_val)
    evaluations.append(evaluation)
  return evaluations

In [22]:
evaluations = k_fold_cross_validation(
  train_df,
  k=CROSS_VALIDATION_K,
  train_model_fn=train_lightgbm,
  preprocess_val_fn=None,
  build_ranking_fn=lambda model: lambda user_id, songs_df: lightgbm_rank_by_user(model, user_id, songs_df))
mapks = list(map(lambda evaluation: evaluation.mapk, evaluations))
np.mean(mapks)

****************************************************************************************************
Taking split number 0 as the validation set...
Train indexes: 5901934
Val indexes: 1475484
Training the model
[10]	training's auc: 0.719914
[20]	training's auc: 0.735223
[30]	training's auc: 0.741772
[40]	training's auc: 0.746365
[50]	training's auc: 0.7503
Evaluating the model
100%|██████████| 100/100 [03:50<00:00,  2.30s/it]
ModelEvaluation(mapk=0.6497512959813556)
****************************************************************************************************
Taking split number 1 as the validation set...
Train indexes: 5901934
Val indexes: 1475484
Training the model
[10]	training's auc: 0.719796
[20]	training's auc: 0.732967
[30]	training's auc: 0.739175
[40]	training's auc: 0.744729
[50]	training's auc: 0.74939
Evaluating the model
100%|██████████| 100/100 [03:51<00:00,  2.32s/it]
ModelEvaluation(mapk=0.6371305555555555)
*********************************************************

0.6470269800327539

In [23]:
X_train = train_df.drop(['target'], axis=1)
y_train = train_df['target'].values
model = train_lightgbm(X_train, y_train)
evaluation = evaluate_model_ranking_fn(
  lambda user_id, songs_df: lightgbm_rank_by_user(model, user_id, songs_df),
  X_train,
  y_train)

Training the model
[10]	training's auc: 0.719427
[20]	training's auc: 0.732026
[30]	training's auc: 0.739266
[40]	training's auc: 0.743502
[50]	training's auc: 0.750068
Evaluating the model
100%|██████████| 100/100 [03:54<00:00,  2.34s/it]
ModelEvaluation(mapk=0.8176242063492063)


In [24]:
np.random.seed(SEED)
explain_indexes = np.random.choice(X_train.index, EXPLAIN_FROM_SAMPLES, replace=False)
X_explain = X_train.iloc[explain_indexes]

In [25]:
explainer = shap.TreeExplainer(model, feature_perturbation='tree_path_dependent')
shap_values = explainer.shap_values(X_explain)

LightGBM binary classifier with TreeExplainer shap values output has changed to a list of ndarray


In [26]:
index = 0
shap.initjs()
shap.force_plot(
  explainer.expected_value[0],
  shap_values[0][index,:],
  X_explain.iloc[index,:])

In [27]:
X_test = test_df.drop(['id'], axis=1)
test_ids = test_df['id'].values

p_test = model.predict(X_test)

submission_df = pd.DataFrame()
submission_df['id'] = test_ids
submission_df['target'] = p_test
submission_df.to_csv('lightgbm_submission_df.csv', index=False, float_format='%.6f')

# Nonclassical embeddings: Word2Vec

In [28]:
user_embeddings = None

def rebuild_user_embeddings(wv, X_train, y_train):
  global user_embeddings
  user_ids = X_train.msno.unique().tolist()
  user_embeddings = collections.defaultdict(lambda: np.zeros(WORD2VEC_EMBEDDING_SIZE))
  user_counts = collections.defaultdict(int)
  print('Building user embeddings')
  for (index, row), target in tqdm.tqdm(zip(X_train.iterrows(), y_train), file=sys.stdout):
    if target != 1.0:
      continue
    user_id = row['msno']
    song_id = row['song_id']
    if song_id in wv:
      user_embeddings[user_id] += wv[song_id]
      user_counts[user_id] += 1
  for user_id in user_embeddings.keys():
    user_embeddings[user_id] /= user_counts[user_id]

def train_word2vec(X_train, y_train):
  user_ids = X_train.msno.unique().tolist()
  print('Users:', len(user_ids))
  
  print('Building sessions')
  sessions = []
  for user_id in user_ids:
    sessions.append(X_train[X_train['msno'] == user_id]['song_id'].tolist())

  word2vec = Word2Vec(
    window=10, size=WORD2VEC_EMBEDDING_SIZE, sg=1, hs=0, negative=10, alpha=0.03, min_alpha=0.0007,
    workers=multiprocessing.cpu_count(), seed=SEED)
  
  print('Building vocabulary')
  word2vec.build_vocab(sessions, progress_per=1000)
  
  print('Training word2vec')
  word2vec.train(sessions, total_examples=word2vec.corpus_count, epochs=WORD2VEC_EPOCHS, report_delay=1)
  print(word2vec)
  
  word2vec.init_sims(replace=True)
  rebuild_user_embeddings(word2vec.wv, X_train, y_train)
  return word2vec

In [29]:
word2vec = train_word2vec(X_train, y_train)

Users: 30755
Building sessions
Building vocabulary
Training word2vec
Word2Vec(vocab=101648, size=32, alpha=0.03)
Building user embeddings
7377418it [06:48, 18076.52it/s]


## Word2Vec similars

In [30]:
def similar_song_ids(song_id, n=10):
  song_embedding = word2vec.wv[song_id]
  similars = word2vec.wv.similar_by_vector(song_embedding, topn=n+1)[1:]
  return similars

In [31]:
song_id = 'CXoTN1eb7AI+DntdU1vbcwGRV4SCIDxZu+YD8JP8r4E='
print(song_extra_info_df_raw[song_extra_info_df_raw.song_id == song_id])

                                             song_id name          isrc
200812  CXoTN1eb7AI+DntdU1vbcwGRV4SCIDxZu+YD8JP8r4E=   焚情  TWB531410010


In [32]:
for similar_id, similarity in similar_song_ids(song_id):
  print(song_extra_info_df_raw[song_extra_info_df_raw.song_id == similar_id])
  print('Similarity is', similarity)

                                            song_id   name          isrc
81585  1CLJhlrmGXmFNrFwH2L6fOMcp4bwYQMJPBbUEitZd2k=  愛你的宿命  TWB531410001
Similarity is 0.8702131509780884
                                             song_id  name          isrc
412187  sGjtw+Dl6XPOd4inrRiEzv1pCeQN9Wtml6egb0CZw2I=  上海姑娘  TWB530610008
Similarity is 0.8310494422912598
                                             song_id name          isrc
152575  pbjE7VuJhc7fiECF0GEp6kNvopMAAskldFR2zS5ThG0=  情歌手  TWA130100804
Similarity is 0.8216755390167236
                                             song_id name          isrc
131161  r5q9D1UeT2JNPFHnVkAU5PeL/8v8ord5CJDI9JQ3US0=   柔軟  TWB531410003
Similarity is 0.8211809992790222
                                             song_id name          isrc
179216  i2r2H9csf9O8WcqfBSluK/Q8/rPFgd+/Ba19s7AbRhY=   自從  TWA130200301
Similarity is 0.8155577778816223
                                             song_id  \
300076  iq0pljzLSoNxy9+hPpRob3Iu2oOB+27B/OELxjh6WWg=   

## Word2Vec as a recommender system

In [33]:
def get_user_word2vec_embedding(user_id):
  global user_embeddings
  return user_embeddings[user_id]

In [34]:
sample_user_id = 'FGtllVqz18RPiwJj/edr2gV78zirAiY/9SmYvia+kCg='
sample_user_embedding = get_user_word2vec_embedding(sample_user_id)
sample_user_embedding

array([-0.18062335, -0.09155007,  0.12431272,  0.02828265,  0.02962735,
        0.08878952,  0.12603288,  0.08118049, -0.13190042, -0.03563733,
        0.02485227, -0.07927597, -0.01223744,  0.12591707,  0.09178266,
        0.0182774 , -0.17914344, -0.14955795,  0.13865888, -0.05123768,
        0.22304841,  0.19278718,  0.15587596, -0.01156536,  0.03680101,
        0.01595459,  0.05715868,  0.20247165, -0.05994776,  0.20214146,
        0.06441861,  0.08595817])

In [35]:
def word2vec_similarity(user_embedding, song_embedding):
  user_embedding_norm = np.linalg.norm(user_embedding)
  song_embedding_norm = np.linalg.norm(song_embedding)
  if user_embedding_norm < 1e-9 or song_embedding_norm < 1e-9:
    similarity = np.nan
  else:
    similarity = np.dot(user_embedding, song_embedding) / (user_embedding_norm * song_embedding_norm)
  return similarity

In [36]:
def word2vec_rank_by_user(wv, user_id, songs_df):
  song_ids = songs_df.song_id.tolist()
  user_embedding = get_user_word2vec_embedding(user_id)
  similarities = []
  for song_id in song_ids:
    song_embedding = wv[song_id] if song_id in wv else np.zeros(WORD2VEC_EMBEDDING_SIZE)
    similarity = word2vec_similarity(user_embedding, song_embedding)
    similarities.append(similarity)
  
  indexes = list(range(len(song_ids)))
  indexes.sort(key=lambda index: similarities[index], reverse=True)
  # TODO(niksaz): Remove this debug output.
  print([similarities[index] for index in indexes])
  return [song_ids[index] for index in indexes]

In [37]:
evaluation = evaluate_model_ranking_fn(
  lambda user_id, songs_df: word2vec_rank_by_user(word2vec.wv, user_id, songs_df),
  X_train,
  y_train)

Evaluating the model
  0%|          | 0/100 [00:00<?, ?it/s][0.8403348488219571, 0.8372094453374771, 0.8316751162553673, 0.8295715352725928, 0.8259724807174932, 0.8155157913147981, 0.812831290460857, 0.8116178927998291, 0.809495891005319, 0.8068232243578378, 0.8023507164400442, 0.8015980026252743, 0.7827639797190786, 0.7825045840894533, 0.7775649665318057, 0.7761423259896509, 0.7746737301389255, 0.7717582445327662, 0.7712479758647394, 0.7684897337245494, 0.7683661812086965, 0.7674490321557714, 0.7652463036159205, 0.7554248741930392, 0.7533015519529948, 0.748115243683336, 0.7440059610158608, 0.7428272502163711, 0.7422521652290875, 0.7415247932432506, 0.7414856326026671, 0.7405499060990787, 0.7405437878730002, 0.7385547450139298, 0.7358625664945478, 0.7336399797626013, 0.7333623790604524, 0.7312754492807341, 0.7263104650553815, 0.7259949600506548, 0.7255922803858273, 0.7253038225037337, 0.7247457229776296, 0.7209606328831706, 0.7206786319975417, 0.7200410823144509, 0.719833115043763, 0.7

  4%|▍         | 4/100 [00:00<00:02, 39.17it/s][0.8773800507726064, 0.8625584734665844, 0.8612113940937318, 0.8585526816969594, 0.8581237603871916, 0.8580178502473581, 0.8538181657044386, 0.8464969382880606, 0.8464411872289541, 0.8427745923719147, 0.8426715198911142, 0.8395964226335022, 0.8395816414444054, 0.8372909711982196, 0.8349661346501464, 0.8299317283549418, 0.8223294723023717, 0.821740690951784, 0.8146616303374136, 0.8083514238637175, 0.7975544479170583, 0.7925142975320727, 0.7877156936238714, 0.7845947581819716, 0.7796304511299529, 0.7767874669518726, 0.7716682179283094, 0.7684376318219593, 0.7672214017783363, 0.7628053246424574, 0.7600237027320049, 0.7588705757183563, 0.757267335404591, 0.7412263387700473, 0.7267623536265959, 0.7244925398599394, 0.7244077092162178, 0.7234395642840644, 0.7183524938686675, 0.7178534297177356, 0.7157182211411274, 0.714827001788841, 0.7116912444282409, 0.7087130858307148, 0.7085012670665816, 0.7069034422593712, 0.6962884089004909, 0.6902438160907

[nan, nan, 0.8156188629271206, 0.7955004869270312, 0.7950367148016867, 0.7945945361335096, 0.7817799899216824, 0.7806715130292804, 0.7669861267561506, 0.7650331788326038, 0.7611884076571576, 0.7503129628523278, 0.7496164439755253, 0.7472013276827957, 0.7460743455515352, 0.7415022489831606, 0.7371557739331054, 0.7370378200392397, 0.7192446317783682, 0.7182404703850864, 0.7178314136744721, 0.7152252811462224, 0.713785177882054, 0.7135686002960365, 0.7133787636223484, 0.7066818725917315, 0.7008724627505707, 0.6990742199278059, 0.6963800575794319, 0.6948678412787723, 0.694574955857034, 0.6934305179637948, 0.6879288124838923, 0.6874558459809176, 0.6854946330562227, 0.685302007652638, 0.6845422705360029, 0.6833565835732688, 0.6820084175891801, 0.6818789965168616, 0.6793618244827917, 0.6764581036315144, 0.6751068433014434, 0.6722186889527678, 0.6720750092682808, 0.6686068228738964, 0.6652754106573013, 0.6633058958002436, 0.658624127980045, 0.6567538457735184, 0.6521533052682275, 0.65123711071

 13%|█▎        | 13/100 [00:00<00:01, 47.07it/s][nan, 0.9131458081899093, 0.9076236484777852, 0.9033999681086741, 0.9032171859785885, 0.9030849258902244, 0.8988544830820346, 0.897249685701704, 0.8966442530478553, 0.8944199167778946, 0.8900896405885614, 0.8880350737240236, 0.887955478000082, 0.8866930778627982, 0.885813988694327, 0.8836189685181178, 0.883242639705943, 0.881882820069748, 0.8794524991788546, 0.8752034662535086, 0.8740536265309745, 0.873471960703862, 0.8715883856010254, 0.8705394942947058, 0.8703711715843564, 0.869811086076592, 0.8688757332826906, 0.8687632599087841, 0.8675783893566328, 0.8658008886888324, 0.8648512488308956, 0.8643824664356563, 0.8640086450870792, 0.8639897525475265, 0.8621633047202735, 0.8612944218181178, 0.8611856288053185, 0.8609423655756817, 0.8603524225125737, 0.8564053341761181, 0.8563122749418219, 0.8562537506334607, 0.8561661061861644, 0.8554478122167476, 0.8549490902474624, 0.8536136747157799, 0.8532918071875812, 0.8529164394227774, 0.8525171664

[0.8644017434415426, 0.8584821304620902, 0.8568560330439459, 0.8561506294553315, 0.8505220647213075, 0.8504013476654458, 0.8499585616191829, 0.8488020277758508, 0.8458673538984417, 0.8431206236155322, 0.8417737312617415, 0.841293561013468, 0.8396605550081284, 0.8394892090724002, 0.8388889512153167, 0.8387343095866475, 0.8379641107300805, 0.8368498413163631, 0.8359424012520722, 0.8325842648083911, 0.8321343887583158, 0.8285819421800951, 0.8283888159574216, 0.8283264308609493, 0.8266607280615011, 0.8248742345597652, 0.8224393384172061, 0.8221613122217472, 0.8219484935928217, 0.8201451923509582, 0.8197983156494247, 0.8191096132763221, 0.8182992712487566, 0.8160255104187563, 0.8147671370933279, 0.813806446032904, 0.8099141993754215, 0.8098661916327765, 0.80945195326376, 0.8094262781066129, 0.8090486711408541, 0.8090281711239609, 0.8078402315816166, 0.8061897450721155, 0.8044957737885211, 0.8044924240665752, 0.8042295663416207, 0.8035498116068757, 0.8034721866551258, 0.8032402919867205, 0.8

 22%|██▏       | 22/100 [00:00<00:01, 54.76it/s][0.9016096037977931, 0.899103116582838, 0.8989424520077579, 0.8934356583765501, 0.8917674994289805, 0.8915983378596467, 0.8780767409940728, 0.874156963564923, 0.8738498599319823, 0.8708644694788898, 0.8679778686075891, 0.867177585711036, 0.8657507607981192, 0.8648226714422025, 0.8647081879422872, 0.8602447192170237, 0.8563573983651406, 0.8550126319489405, 0.8547774320552555, 0.8536809409290347, 0.8496109253510127, 0.8490301974535011, 0.8481539229164466, 0.8476636403451501, 0.8441576912808995, 0.8428442592228081, 0.8423104339085432, 0.8415256717433572, 0.8399843607452511, 0.8399519473526914, 0.8383704179544682, 0.8379100474787633, 0.8364223699330511, 0.8358317417112515, 0.835427386090673, 0.835056509194581, 0.834574035534864, 0.8334414420834574, 0.8332677935416871, 0.8330536305355183, 0.8326432234105715, 0.8326008461616065, 0.8325870211103894, 0.8325824347546618, 0.8323907263999027, 0.8317115131445947, 0.8315712629062864, 0.83147985587711

[0.8817823567591759, 0.8753490958264084, 0.8748771978453302, 0.8676830628594823, 0.8640241706287327, 0.8605058924678447, 0.8590244264653263, 0.8507128540942352, 0.8480309366368821, 0.8442565877942453, 0.843508612240516, 0.8417199128251946, 0.8401291131341144, 0.8392868269620364, 0.8389634343844901, 0.8383242206118336, 0.8361234311069724, 0.8338367475296878, 0.8338021355954011, 0.8336943898837724, 0.833500067741734, 0.8271837254133788, 0.8270381705066798, 0.8253248572418211, 0.825195884834434, 0.8231910456194238, 0.8210187315884877, 0.815870337430221, 0.8149706884667043, 0.8125305226060113, 0.8116909891413947, 0.8082632953391948, 0.8068052256374444, 0.8065209730081517, 0.8037860624305088, 0.8029398994943399, 0.8026318174335357, 0.8021879166372091, 0.8017282198803135, 0.7993057380708476, 0.7962134365132242, 0.7939824862670332, 0.7936772613512, 0.7914115658261867, 0.7901696687475774, 0.789981353060076, 0.7888959507624821, 0.7865174006095109, 0.7853599923990982, 0.7847480868668975, 0.78309

 32%|███▏      | 32/100 [00:00<00:01, 62.21it/s][0.8603166547333538, 0.8535440860362603, 0.8155186902931651, 0.8086754264515623, 0.8003012110533902, 0.798934081717391, 0.7921625871970065, 0.7867870085166062, 0.7763186841255847, 0.7727899490923855, 0.7673700698391746, 0.7660355874102104, 0.7626598901852344, 0.7520759932944938, 0.7504718679291744, 0.7468203689813644, 0.7444349221589652, 0.7375461806075351, 0.7354780299276825, 0.7314948149670806, 0.7252799198356707, 0.7107793356336114, 0.7038551925122628, 0.6994647910009802, 0.6986701191735558, 0.6948587814823204, 0.6944409933905229, 0.6944037943871946, 0.6928805208098288, 0.6925572542737347, 0.6918230593442458, 0.6827112911178171, 0.6815031361933852, 0.6800466880330436, 0.6619800467711603, 0.6615327382230133, 0.6484178012235076, 0.6473949005396692, 0.6473724462897826, 0.6429846034035817, 0.6296525037613886, 0.6236980371133602, 0.6227188198413691, 0.6206394132047418, 0.6200669249589027, 0.618014251274128, 0.6168976308785961, 0.6140509468

[nan, nan, nan, 0.7644985752183007, 0.7101800195783332, 0.6884132520276076, 0.6815510576920046, 0.6502702404432452, 0.6164511224299706, 0.6128946054118554, 0.6006540696557748, 0.5920951957870935, 0.589838130842284, 0.5861111925094986, 0.5842332112319659, 0.5837740482150688, 0.5752520712704385, 0.5697506733257716, 0.5678886231034355, 0.567443330638722, 0.5617673349175981, 0.5613038446539884, 0.5604126545440037, 0.5595419035874393, 0.553747634162543, 0.5480735054911732, 0.541650659633407, 0.5397089700651708, 0.5236703394730515, 0.5211598559129864, 0.5089392530094025, 0.5053191584750584, 0.47350329887220516, nan, nan, nan, 0.567735780853918, 0.5587304471466077, 0.5432177584261775, 0.5115430285211965, nan, 0.7444275842498155, nan, 0.717480794109244, 0.7151550933294663, 0.7113171234203183, 0.7040644996375671, 0.688942299431249, nan, nan, 0.699927697918317, 0.6999182860746539, 0.6892384436281028, 0.6720287753398272, 0.6695294130145701, 0.6671007063845004, 0.6661968060284059, 0.65203088692092

[0.9475684882727273, 0.9348928643801694, 0.9303901471986411, 0.9301474379407005, 0.9274578034791185, 0.9231887289417408, 0.9185732112720186, 0.9180300338241707, 0.9172127116508481, 0.9162180840122106, 0.9118863014359208, 0.911818644035976, 0.9116979314330805, 0.9087184360373193, 0.9086348565694582, 0.9079785660786198, 0.9033528272188266, 0.9032377549584378, 0.9030816828621301, 0.9025313535427754, 0.8964138556805731, 0.8943459080393893, 0.8933088163746535, 0.8906465000385742, 0.8880597123402205, 0.8844096156254712, 0.8828587934289278, 0.8810878488413275, 0.8749709381586211, 0.8744967049355361, 0.8706558560681834, 0.8636419881138365, 0.8636327230203985, 0.8488492872451254, 0.8476438303917203, 0.8407084973767823, 0.8404927358934169, 0.8394105527870692, 0.8378437202776011, 0.8336826114368968, 0.8281937664572281, 0.8248373721579743, 0.805368893846349, 0.7754551857090399, 0.7341318858110168, 0.7299256334118185, 0.7244585154311789, 0.7185359140724026, 0.7149375220168337, 0.7147585552624451, 0

[0.8412781894445053, 0.8140029411074093, 0.7960546261313103, 0.793277894149233, 0.7920725909958184, 0.7887008451859246, 0.7865041770741289, 0.7790087594527992, 0.7774886857444214, 0.7755670012175079, 0.7731568874982505, 0.7686851198698156, 0.7686502276354593, 0.7679975587856286, 0.762170078567209, 0.7615305527286005, 0.7601614657092627, 0.7477857766036121, 0.7477382874129147, 0.7471905391847216, 0.7463387121629689, 0.7436826869438417, 0.7433422362070283, 0.7420938580500959, 0.7412456923919374, 0.7391682268920401, 0.7305864669355768, 0.7272543688817431, 0.7256521978597867, 0.720506427320279, 0.7202864615387435, 0.7196364310768728, 0.7061084642482061, 0.7058296176018855, 0.6940217067107878, 0.69026275409477, 0.6796245430559352, 0.6774097799397658, 0.6512479799048596, 0.6474524433035135, 0.6452829394344357, 0.6258256281335275, 0.616930409804813, 0.5668213610769988, 0.4144287315941396, nan, 0.7520914789897823, 0.7243659746725302, 0.707822681096132, nan, 0.7152277422921184, 0.70808382007804

 50%|█████     | 50/100 [00:00<00:00, 65.88it/s][0.851952047575956, 0.8417148437928316, 0.8204531102462578, 0.8203675184782885, 0.8165511224684409, 0.8056341989437483, 0.8000798364126778, 0.7908174951802732, 0.7864445278077787, 0.7833506033056366, 0.7826597622426912, 0.780353241782825, 0.7794621625337612, 0.7787309314584391, 0.7738736265144293, 0.7725587618711114, 0.7702017515022805, 0.7697674107933851, 0.7690663653703761, 0.7685770025812096, 0.7657656317182634, 0.7631958212072262, 0.749387654386616, 0.7482141187443109, 0.7462236982094342, 0.740493467352568, 0.7380596268271123, 0.7376733224293739, 0.735632214338944, 0.7275195639072843, 0.7261207375851477, 0.7257657323476334, 0.7255217114400807, 0.7229438017733515, 0.7214865277574852, 0.7187662529307552, 0.7177397190251521, 0.7151248779222124, 0.7136167338587334, 0.7132003647249985, 0.7108460753301403, 0.7080549339089434, 0.706150698252567, 0.7020749003525191, 0.6991375670442492, 0.6962493587076273, 0.6946216785809863, 0.68684344535614

[0.8749691849101243, 0.8687773103345913, 0.8609221482749263, 0.8605938749491036, 0.8568645660953302, 0.8542857816654467, 0.8535175188650231, 0.8533999655314308, 0.851099944367383, 0.8502093304762876, 0.8468164281086001, 0.8423586205115305, 0.8417329193009846, 0.84122861280996, 0.8333372047694783, 0.8321735520147733, 0.8317816434093047, 0.8315993955970765, 0.8296030594613195, 0.8293222833540337, 0.8279940061792971, 0.827591176438426, 0.8273621534037143, 0.8273351377858348, 0.8252500438757606, 0.8246450686767438, 0.8207115183192161, 0.8201899084569155, 0.8193691172848918, 0.8192886469526633, 0.8190982938625668, 0.8177008168434705, 0.8154122133665077, 0.8150976582276827, 0.8144882794113818, 0.8141598599020989, 0.8138491141560563, 0.812311890580767, 0.8107698112361441, 0.8103919669365628, 0.8102239818741342, 0.8101831755754982, 0.8085386399155375, 0.8060476580558724, 0.8059826545249154, 0.8037771477889498, 0.8030055982995609, 0.8011711087660727, 0.8007629760071122, 0.7995229410374786, 0.79

[0.8742631056350829, 0.8546948070759747, 0.7829526944037578, nan, 0.8497331188542009, nan, 0.8286835137178197, 0.8233133120859132, 0.8131903229201524, 0.8007120853249463, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, 0.8595093684682721, 0.8415038284128973, 0.8241451839053857, 0.7996887764347953, 0.7665033438162218, nan, 0.8411287402755951, nan, 0.8134584369579696, 0.7630495299580509, 0.7383694248480068, nan, nan, nan, 0.8317761166215717, 0.828244810859137, 0.825827820844266, 0.8123475450729649, 0.8104504889542069, 0.7895802004172308, 0.7809190331245965, 0.7796158470670654, 0.7761351434443741, 0.76219295019658, 0.7598608914165637, 0.7369306436975778, 0.7293542907733352, 0.7145629205691875, 0.6987301096046731, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, 0.847688765527461, 0.8175340425352027, 0.8068770235125168, 0.7885889053430694, 0.7452762814952636, 0.730213848997723, 0.7296402346879516, 0.7225613127094586, 0.6894297925035031, 0.6855543699907405, 0.6783618420338553, 

[0.8907734079872442, 0.8867552534293472, 0.8616904205905643, 0.855931957050453, 0.855774473270064, 0.8487945410000483, 0.847979526458576, 0.8465941775434526, 0.8447066216244977, 0.8432150007738769, 0.8407621536409997, 0.8335381088165569, 0.8334911284891701, 0.832846842580062, 0.8297395707328977, 0.8293984769070436, 0.828566639492426, 0.8268744107032384, 0.8254410212181024, 0.8228694793472978, 0.8209697707310885, 0.8205811302399351, 0.8158397330896093, 0.8129215484564934, 0.81265649898175, 0.8125994332682227, 0.8117989379827663, 0.8114760922922253, 0.8105942188619149, 0.8059960029461292, 0.803298568290328, 0.8030795373260795, 0.8002561173745232, 0.7989434512105004, 0.7968943647232992, 0.796192518258142, 0.7960983702940897, 0.7938969586294138, 0.7921390394665407, 0.7879764209419561, 0.7875830660924977, 0.7868755548253974, 0.7864722571910563, 0.7859917454169367, 0.7858918705685708, 0.7844321954300842, 0.7828774481054055, 0.7802323532217771, 0.7797460119285612, 0.7767083389223174, 0.776157

 59%|█████▉    | 59/100 [00:00<00:00, 70.89it/s][0.8014408516898123, 0.7964242263648834, 0.7953886195550591, 0.7888482555024424, 0.7880750702370449, 0.7822936295041091, 0.7778577258173842, 0.7737589599679817, 0.773700434545928, 0.770688668896221, 0.7692309973788017, 0.7686452081845125, 0.7683678827121786, 0.768366505286892, 0.7638597410677324, 0.7604253764844463, 0.7601926125592191, 0.7585442445962474, 0.7582166213702042, 0.7497898380995847, 0.7410584418869244, 0.7379225455943703, 0.7361538448775448, 0.7232726613819672, 0.7206705877825125, 0.7178279264878737, 0.7133032126796998, 0.7115331841089007, 0.6988468484841032, 0.6974659877921764, 0.6967961082551425, 0.688314490709421, 0.6872513012107014, 0.6869018779811507, 0.6828717930778998, 0.6803434952941768, 0.6800145638684346, 0.676019936985032, 0.6736291701109591, 0.6717402014392327, 0.6685834746872612, 0.6684047937589329, 0.6647011680171843, 0.664659367704145, 0.6607850182840362, 0.6542614046096648, 0.6529613272002456, 0.64568036837793

[0.8566298899480422, 0.838312704100897, 0.8308498000070027, 0.8180470602315757, 0.8123563741067303, 0.8118830673566879, 0.8056556416032545, 0.8009116354040842, 0.7993358942341493, 0.7952046938601233, 0.7874504220019591, 0.776473976870866, 0.7744048425380838, 0.7733471995132948, 0.7685959209447363, 0.7670360429323175, 0.7666454202365846, 0.7664584002513555, 0.7663460540462506, 0.7661867631811119, 0.7607608521428938, 0.760007237728417, 0.75911120898366, 0.7562270636805364, 0.755358078562723, 0.7542815775307947, 0.7538685146651936, 0.7532862358416572, 0.7520110092112557, 0.7513545183570288, 0.7511406270783361, 0.7491359137007809, 0.7489869064577881, 0.7482422450090267, 0.743331925776738, 0.7408160024431553, 0.7408138121080259, 0.7323315522289402, 0.7317615851667941, 0.7313346800954675, 0.7296567317592121, 0.7292388908693642, 0.7267754837291446, 0.725146876389331, 0.7210277012770999, 0.7100493879262297, 0.7096580395536496, 0.7068747196623039, 0.7067642886304428, 0.7062456800260906, 0.70411

 68%|██████▊   | 68/100 [00:00<00:00, 75.08it/s][nan, 0.8254771106930299, 0.821354679764508, 0.8092284156203468, 0.8077109601809427, 0.8041571152905312, 0.8020898339346565, 0.8011590248015118, 0.7992603842829358, 0.7991656106497154, 0.7952174620676876, 0.7718014553950708, 0.7660208592931896, 0.7526391622662754, 0.7512355439689483, 0.7211391263414069, 0.7071664058360939, 0.7051697477818974, 0.6982015477985647, 0.6946446742721734, 0.6928669976099701, 0.6908913408374802, 0.6704004218962976, 0.6642970285724767, 0.6547793563896932, 0.6378394632843348, 0.6285025492697006, 0.6276897769310954, 0.6256352519736895, 0.621483959926456, 0.6078824477546338, 0.6028809571379008, 0.5972514676258668, 0.5908484110509226, 0.5900087717867397, 0.588479674561507, 0.5870852391282352, 0.5807544576206236, 0.5774904404676918, 0.5774681403670859, 0.5730510817676475, 0.5688039143435978, 0.5499154491297398, 0.5469454121470434, 0.5349386608203327, 0.5269209350355343, 0.5237675853051835, 0.5190495068754531, 0.501631

[0.849649335008998, 0.8335189985630405, 0.8201344690529271, 0.8199320137667748, 0.817529903515188, 0.8085572790066797, 0.8003774601306539, 0.799549133020053, 0.7994678684355861, 0.7975826175719453, 0.7971441137970178, 0.7961673198583424, 0.792911187489902, 0.7881541465126681, 0.7858141353996193, 0.7811513602357159, 0.778224543512182, 0.7778693879823099, 0.7754987000208707, 0.7728774983174886, 0.7678896160666682, 0.7672967588027446, 0.7664242671132517, 0.7595113258614473, 0.7593441805049843, 0.7542562839535999, 0.7515577594559533, 0.7515308358700806, 0.748771727243421, 0.7317660066290229, 0.7278857892532604, 0.7275346911674963, 0.7273406730778593, 0.7257709806975422, 0.7253053189105356, 0.7236587184402662, 0.7217669332818967, 0.7142402204711874, 0.71301078527611, 0.7089730676679624, 0.7066393813408137, 0.7058329240712777, 0.705604975546718, 0.7044242819270289, 0.7011756595419234, 0.7004921878209541, 0.6991315590589283, 0.6972234776076229, 0.6960018718234701, 0.6946678154716257, 0.693683

[0.881748959372514, 0.8588985806823606, 0.8452456263391868, 0.8442996585675976, 0.8432531757833244, 0.8349116683952894, 0.8344221199797679, 0.8289949516096529, 0.8287269298131053, 0.8258511869847219, 0.8254585304001126, 0.8247488394114083, 0.8222262322095328, 0.8219670671281112, 0.8201832146930155, 0.8190728703850778, 0.8188999916467391, 0.8187486361140311, 0.8187231282003307, 0.8175784806656029, 0.8126444722902284, 0.8116961149938186, 0.8115631308519353, 0.8084057846311334, 0.8067993902879476, 0.8061736889879734, 0.8058902567440254, 0.805794560625079, 0.8039795290797626, 0.8027503488803764, 0.8026751137566602, 0.8017561605745152, 0.8006462291351535, 0.8006175612816261, 0.8005717055594044, 0.7991983207220017, 0.7969246753625475, 0.795204966327358, 0.794903124802065, 0.794526048075366, 0.7940038518513669, 0.7938959003047084, 0.7931693179547311, 0.79211048426736, 0.7911502590009527, 0.790305067161994, 0.7902360758839857, 0.7899257489603324, 0.7899016923523424, 0.7885349562491584, 0.78752

[nan, 0.8976952777206807, 0.8957917864961665, 0.8791067025961814, 0.8788603656501197, 0.8212177040690132, 0.815084446440597, 0.8001107445223776, 0.5867664061106147, nan, 0.8000954201153792, 0.7846480631793976, 0.7597070937876261, 0.5382123215552337, 0.5269034350706843, 0.526445962557127, 0.5123069772523984, 0.5065972668417961, 0.49762103211675174, nan, 0.717741776370483, 0.7041477727472144, 0.6304769984131731, 0.6073820792487212, 0.5103593191963364, 0.499671196520455, 0.4968485927024838, 0.4964786461175481, 0.49532168289716483, 0.47692827575270136, 0.4767029316569526, 0.4526482618229137, 0.44355870947884135, 0.43751728471531154, 0.4275494969929836, 0.38604206744974845, 0.36207894644933813, 0.35089119981179695, 0.2897656137882268, nan, 0.8758236523385395, 0.8465494023742781, nan, 0.8673677031012973, 0.8509080724676003, 0.8210065625434919, 0.8076284218686541, 0.8040247167041091, 0.8029290673734931, 0.7936411464332577, 0.793363671147743, 0.7827865198731045, 0.7818814306481234, 0.773809419

 78%|███████▊  | 78/100 [00:00<00:00, 79.73it/s][0.8351085386596011, 0.8272090709831085, 0.8241220571174153, 0.8169727387505509, 0.812625671887643, 0.8014826215481767, 0.8012216506525119, 0.7972200556892891, 0.7952874884068417, 0.7924040755141085, 0.7904843775688916, 0.7901566008294323, 0.7894929256402858, 0.7864725597868858, 0.7857432801449504, 0.7786288616526506, 0.7757991138003693, 0.7729693662620571, 0.7725575309410903, 0.7719266936817166, 0.7714727793489112, 0.7666398129132627, 0.7658581181467986, 0.7597711080541255, 0.7585633564253058, 0.7585289804338071, 0.7565544484052048, 0.755651744086753, 0.7548286667418627, 0.7487401835327362, 0.7481514099616775, 0.747740214351106, nan, nan, nan, 0.838932113705757, 0.8067962603615908, 0.7990965995268308, 0.7906140286232848, 0.7845732070531332, 0.7839037337006655, 0.7783388165829219, 0.7779045442224288, 0.7696068690632822, 0.7669126506446953, 0.7591041436624352, 0.7565071446491649, 0.7485774354855246, 0.7472980163789689, 0.7472741251761792,

[0.8894612069932671, 0.8869510905084582, 0.8703052485505347, 0.857462558695573, nan, 0.8555693567076744, 0.8552267539194912, 0.851769633557596, 0.8502826670004664, 0.8498672428522653, 0.8435713168452327, 0.8421515867945443, 0.8345994232329965, 0.8334934745970005, 0.8307685259600053, 0.8266771685800413, 0.7983783891006562, 0.7963086399539223, 0.7887882895202555, 0.7859792290696622, 0.7853562424166727, 0.7837114293545023, 0.780426386943118, 0.7796968067768217, 0.7742687302417063, 0.7699721565920594, 0.7642549738308068, 0.7637407153006848, nan, 0.8113889936434436, 0.7622367277569577, 0.7611940498719945, 0.7562946514867682, 0.7497480637712356, 0.7473459443693424, 0.74515454101968, 0.7444623139362864, 0.7443719163259985, 0.7408080731357679, 0.7362888941369858, 0.7353037920806007, 0.7330034843339254, 0.7305206805345438, 0.7302311640536626, 0.7277821221006882, 0.7271707188069593, 0.7218374600779079, 0.7198548577623708, 0.7135194275476837, 0.7103939439123484, 0.7052388320876963, 0.703763669601

 88%|████████▊ | 88/100 [00:01<00:00, 83.96it/s][0.8154432870478431, 0.7775515737573454, 0.7489295045772528, 0.7469642183162798, 0.741872243878289, 0.7329598815119507, 0.729612381498582, 0.7053166572342179, 0.696393605718852, 0.6956736406839303, 0.6505525881696209, 0.6281459036253277]
[0.8368834484661927, 0.834977165787353, 0.825592235903564, 0.8189056828700966, 0.8187923777148908, 0.8130242779650851, 0.8108273919654205, 0.8103070253984287, 0.8099990800908848, 0.8096407457074422, 0.8091444600171119, 0.8090542491491345, 0.8082649545585221, 0.8067303750803038, 0.8057604169706075, 0.8032549687720055, 0.8011858106610835, 0.7999140321221334, 0.7977244702415573, 0.7970671049253033, 0.7966439832925756, 0.7936952147962979, 0.792355558249566, 0.7916942458290592, 0.7912013089878759, 0.7904547741296639, 0.7903669805000743, 0.788379091195902, 0.7877121808231047, 0.7876241773347123, 0.7876121048213547, 0.7859876846766078, 0.7858785430170588, 0.7844201148584966, 0.7811399732361491, 0.78038619695535

[0.9311744152192173, 0.9134475773070263, 0.8970218313335295, 0.8936648878214074, 0.8908117332290233, 0.8865073917795909, 0.8764322898200796, 0.8748847149126356, 0.8739176217694157, 0.8695817908115168, 0.8691863696967437, 0.8685913225592216, 0.8678803089386165, 0.8656437073454804, 0.8656026274217087, 0.8647532300133346, 0.8647020044323825, 0.8638715648520975, 0.8633510510235743, 0.8627084262464106, 0.8624785603551612, 0.8624022423987713, 0.860796969324513, 0.8598805543262594, 0.8535759549612544, 0.851998973229728, 0.8494198136685327, 0.8456406003860184, 0.8439456684688591, 0.8418090199889087, 0.841229643955316, 0.8392807673338761, 0.8342823236407808, 0.8342585901510088, 0.8320405658781056, 0.8292760875211882, 0.8277128976905025, 0.8238562209239095, 0.8236836022119189, 0.8195986018611481, 0.8181036914302316, 0.8174622050912901, 0.8154581921254743, 0.8127335601622288, 0.805535063193944, 0.7972508618941915, 0.7951130020422568, 0.7946371319778623, 0.7946290815639945, 0.791048298465933, 0.78

[0.8582752603518758, 0.8224916572287002, 0.8203805672654892, 0.8196126968966873, 0.8194845780939537, 0.8184850054222778, 0.8175928787544071, 0.8161638626301816, 0.8152985315048122, 0.8150196043953676, 0.8145488935424111, 0.8141058582575238, 0.8128948946877937, 0.8126758321635561, 0.8118183093525664, 0.8107871416526934, 0.8019035898723955, 0.8000795121596774, 0.8000512965955636, 0.797923078816562, 0.7942821145878447, 0.7893430725889553, 0.7832689693535408, 0.7815273137069104, 0.7800603285742662, 0.7792426683663337, 0.7782398853708077, 0.7772577826494957, 0.776977346240097, 0.773297622682094, 0.7718040672686026, 0.7715113990505869, 0.7696153644145232, 0.7695700700839495, 0.7687927091899406, 0.7683415456607604, 0.7656193774789847, 0.7650404076860358, 0.7612249215531387, 0.7594673394913967, 0.7580228319290994, 0.755043841339365, 0.7535626271908552, 0.75316266807235, 0.751299882954583, 0.751230613891708, 0.7503469823691562, 0.7500585474143497, 0.7431928590501617, 0.741081784316504, 0.734752

[0.9367153756498957, 0.934772927195238, 0.9325877155908866, 0.9288704069565016, 0.9233840359005241, 0.9211960682405647, 0.9203378069702955, 0.9195661579120639, 0.917766318493343, 0.9151770421185018, 0.9148466321329501, 0.9146516757686382, 0.914620277071798, 0.9142856204955979, 0.9133855219225282, 0.9125935770878993, 0.9125686314883766, 0.9116443337157337, 0.9114133309181359, 0.9109075756621573, 0.9096768769037304, 0.9089705001675126, 0.9089370950066229, 0.9088555081688604, 0.9076507633102251, 0.9074972888287569, 0.9070540796441068, 0.9070084706547735, 0.9057978428845278, 0.9050651528143986, 0.90475219217315, 0.9045652304269647, 0.9004649744349198, 0.8993849358150845, 0.8987044533659501, 0.8986608676234962, 0.8980461131820265, 0.897857393429826, 0.8976514823846418, 0.8971597892846933, 0.8968202252017837, 0.896010926435072, 0.8939120779902925, 0.8938877505927909, 0.8935600985024458, 0.8934094173937303, 0.8928923190406759, 0.892384427020994, 0.8910105615174445, 0.890806443235996, 0.890371

 99%|█████████▉| 99/100 [00:01<00:00, 89.32it/s][0.7843566293198136, 0.7724002507123898, 0.7676408830108097, 0.748886311423223, 0.7418867267887965, 0.7351914686841801, 0.7313154386647791, 0.7279952753987481, 0.7164221141104069, 0.7126908323023745, 0.7009394926061773, 0.6985348473249141, 0.6929204504844226, 0.6926981731730655, 0.6880754740517068, 0.6842119745536241, 0.6811850523998347, 0.6766205305578612, 0.676274847115059, 0.6756841400022167, 0.6755116176562588, 0.6746697676139084, 0.6705334186766809, 0.6649179156808985, 0.6623125123610356, 0.6601078418830909, 0.6594149209426489, 0.6508456173890499, 0.6499947743068374, 0.6482972265580405, 0.6481267203296264, 0.6430465992867211, 0.6415301874770206, 0.6397601715156429, 0.6389868996490076, 0.6383263614131884, 0.6369067440465638, 0.6344344617373603, 0.6338557818295479, 0.6337217101078, 0.6332091007038594, 0.632625446114532, 0.6307491640074628, 0.6306305460842057, 0.6300242597151582, 0.6289232822370813, 0.6261221351731379, 0.623342733338028

100%|██████████| 100/100 [00:01<00:00, 83.05it/s]
ModelEvaluation(mapk=0.7394630952380952)


In [38]:
evaluations = k_fold_cross_validation(
  train_df,
  k=CROSS_VALIDATION_K,
  train_model_fn=train_word2vec,
  preprocess_val_fn=None,
  build_ranking_fn=lambda word2vec: lambda user_id, songs_df: word2vec_rank_by_user(word2vec.wv, user_id, songs_df))
mapks = list(map(lambda evaluation: evaluation.mapk, evaluations))
np.mean(mapks)

****************************************************************************************************
Taking split number 0 as the validation set...
Train indexes: 5901934
Val indexes: 1475484
Users: 30548
Building sessions
Building vocabulary
Training word2vec
Word2Vec(vocab=88866, size=32, alpha=0.03)
Building user embeddings
5901934it [05:22, 18307.65it/s]
Evaluating the model
  0%|          | 0/100 [00:00<?, ?it/s][nan, nan, 0.7880005389308355, 0.7101731311951163, nan, nan, nan, nan, 0.7748911874260385, 0.7479348832680403, 0.7448554576789017, nan, 0.7472644274770357, 0.7413766472868255, 0.6703330586545072, 0.6559112264128524, 0.6323498022860515, 0.6188911196773353, 0.6166856151001509, 0.6100725705763418, 0.6098558428744367, nan, nan, 0.6362692287818092, 0.6309063742526874, 0.6000691412550989, 0.5998039358352925, 0.5900736969185105, 0.5615207931363457, 0.5556093051048546, 0.5442807337565491, 0.5439947761619812, 0.5429990418947019, 0.539857906959352, 0.5318945453780919, 0.510189699790

 24%|██▍       | 24/100 [00:00<00:00, 230.88it/s][0.8453213973222007, nan, 0.8502536772324023, 0.8276510051541799, 0.8124952704771639, 0.7330068474159356, nan, 0.7956572301822575, 0.7897052455945858, 0.785900728648725, 0.785820138186086, 0.7658809855294935, nan, 0.7773287380663404, 0.7669200418637753, 0.7463668503551856, 0.7456346362276328, 0.7377801047385141, 0.7297357123813183, 0.7231038636526314, 0.7125520740627298, 0.7025689862072023, 0.6975917797359881, 0.6893874650040588, 0.6834129906609983, 0.6722686973998063, 0.6677533883418956, 0.6669118745109355, 0.662752451762519, 0.6594756186686698, 0.6576136441850455, 0.6550089113137881, 0.6533872339734997, 0.6524070433550161, 0.6340558730484478, 0.6321532883295327, 0.6276634253662858, 0.6187437880026727, 0.6182491910008747, 0.6117934912354741, 0.5983328488553125, 0.5960805549379208, 0.5937054759047137, 0.5908508317885073, 0.5899801941993595, 0.5645525615327907, 0.5584085136064738, 0.5578448794282899, 0.5537054606033418, 0.5523509767736453

[0.9076116153433352, 0.906428161741543, 0.8981238493209486, 0.8939239596955125, 0.8824748049159493, 0.8787618350251006, 0.8750389052542447, 0.8616971205236442, 0.8604694927030331, 0.8598346941625159, 0.8564657971403105, 0.8476858859111216, 0.8341541648463081, 0.8265350295826035, 0.8247413322390683, 0.8176221411637604, 0.8154901669937724, 0.8154508897271139, 0.8154034796686493, 0.8128675863949552, 0.8125161068088369, 0.8114407775118592, 0.8104463557582171, 0.8103868824475418, 0.7978684707813062, 0.794440246288384, 0.7938772866046556, 0.793566204174808, 0.7907746222056466, 0.7852174436123456, 0.7844396823023765, 0.7841123276383594, 0.7805587302060346, 0.7801695768591552, 0.7796554959645722, 0.77959690865507, 0.7724875147072335, 0.768873474479383, 0.768555602365889, 0.7673327854143187, 0.7636843967336586, 0.7615892992102795, 0.7595412053103822, 0.753633219527769, 0.7529666330030035, 0.7475168415519867, 0.7398396743869022, 0.7380057920778145, 0.7361159724527863, 0.7348705776636871, 0.73417

[0.8598957444607794, nan, nan, 0.8518932299571412, 0.8496718111376714, 0.8492604394374708, 0.8489218455332282, 0.8454014184135502, 0.8448682956428022, 0.8283772564041362, 0.8270987712897498, 0.8211809368972315, 0.8210606904713409, 0.8157709423305148, 0.7670068630065758, 0.7615852420253714, 0.7582703724574268, 0.7566174571475044, 0.7479081186619484, 0.7421373908716593, 0.7362807627639746, 0.7309695586883707, 0.7259796082102226, 0.7240255294771981, 0.7215691458387705, 0.7040955883405744, 0.6787685143457133, 0.6740702978790636, 0.673831954938764, 0.6736827332628798, 0.6668265562310691, 0.665418108799079, 0.6646907934254159, 0.6637599306371168, 0.6543020280407209, 0.6504817169919893, 0.6405992892247113, 0.6356498717873944, 0.593731655176638, 0.5934635883011329, 0.5890978377248498, 0.5720727172517657, 0.5358469755454143, 0.533734435308102, 0.5085010086947352, 0.4808668296764208, 0.45507545626050405, 0.4054836619218889, nan, 0.8437485849258523, 0.8034445870207323, 0.7995345665884402, 0.79653

 74%|███████▍  | 74/100 [00:00<00:00, 238.29it/s][0.8493581066779426, 0.8398373791305789, 0.8285484500680954, 0.7284456839356662, 0.7268019222169474, 0.7243716208686191, 0.6990408088947382, 0.6894067988666509, 0.6796065740570281, 0.6756958019338168, 0.673084767570446, 0.6714986872181427, 0.6675430672475888, 0.6626178894500835, 0.6474051931177781, 0.6327164112077679, 0.6082115666615794, 0.604539031769092, 0.594095165438788, 0.5290124946607184, 0.5152476731338869, 0.4982494233925669, 0.4673920342986929]
[0.840575373472402, 0.8267000438336212, 0.8178784029761673, 0.8160134541442002, 0.8149542672097733, 0.8124185280197831, 0.8104045831068876, nan, nan, 0.8151213262040767, 0.8095319006186821, 0.8081480855851026, 0.8049715223901116, 0.8044876992134707, 0.8044109997343801, 0.8026116822382966, 0.8007765413549977, 0.8005870552338628, 0.7902764438371446, 0.7886915790559741, 0.7854356967857028, 0.7837998000995652, 0.7828685264995887, 0.7793044885935495, 0.7788511270174581, 0.776392965313994, 0.77

[0.81421718343032, 0.7298926533659531, 0.7214073829825168, nan, nan, nan, nan, nan, 0.8125080332260596, 0.7666118746864391, 0.7262220370152588, 0.7188863678686659, 0.6799108545088846, 0.6677037832445493, 0.66444457146927, nan, 0.792621614148165, 0.786026895860506, 0.6991327863169485, 0.6709143848978824, 0.66221868823462, 0.6614320604768054, 0.6613359283050251, 0.6611455537828594, 0.6502479786003877, 0.6361881830731984, 0.6310020505221859, 0.6260682638466352, 0.622951831064889, 0.6172821587715063, 0.6132789122922179, 0.6095211860090506, 0.5961671616159577, 0.5789488522677287, 0.5681749311584412, 0.5664378792231712, 0.5660812450402406, 0.5416544044317434, 0.5379469110410035, 0.527033669117055, 0.5263013369058018, 0.5171822855116559, 0.5171569926885451, 0.5158040198749838, 0.49418435183382037, 0.4644759291419456, 0.41970091275767374, nan, nan, nan, nan, 0.7074734433994667, 0.7065253129793784, nan, 0.8045718434986302, nan, 0.8071914557849572, 0.7933078293996576, 0.7600260886645371, 0.72286

[0.8946471965155885, 0.8481456942592576, 0.836979183702213, 0.8342571458187438, 0.8322941785898961, 0.8245175277363452, 0.8217654488202547, 0.8162838374831657, 0.8153824655888267, 0.8103786650027999, 0.8042188142081356, 0.7966134380596616, 0.7884291767653894, 0.7864695925867917, 0.7851582656661764, 0.7827162992310157, 0.7788570967769785, 0.7742707413369714, 0.7672506745626758, 0.7565949746592522, 0.7544804227899449, 0.7516731548277462, 0.7453336691134708, 0.7451255223320626, 0.7424388084550889, 0.7350675694891423, 0.734585895358164, 0.7313545690001262, 0.7227263742269665, 0.717096206661446, 0.7147492707648316, 0.7123700394790397, 0.7111557010890394, 0.7105630110161494, 0.7081900718594804, 0.707416948511039, 0.7070017173359348, 0.7045754380998248, 0.7029572253494728, 0.6986866130086433, 0.6929068362112016, 0.6900934746289238, 0.6842469892648774, 0.6729584990291998, 0.6727018374840379, 0.6723855252013171, 0.6716579239239052, 0.6621771587355657, 0.6577113639520827, 0.6535759892192214, 0.6

 28%|██▊       | 28/100 [00:00<00:00, 276.79it/s][nan, nan]
[nan, 0.7909670424372071, 0.7818261691905968, 0.7754104996733596, 0.7750066157070682, 0.7535844071059662, 0.7511000090844224, 0.7212714482820749, 0.7114360665383908, 0.6751926338466221, nan, 0.7195967738229866, 0.6405418228609577, 0.638214867240253, 0.6377692520432547, 0.6153619675994929, nan, 0.7601361878462942, nan, 0.775112180838256, 0.7589856277713269, 0.7527940642238691, 0.732675030315506, 0.697635726029633, 0.6906989843205088, 0.6581268556750101, 0.6406246283673777, 0.6214096198791496, 0.6150820491459504, 0.6144860474113837, 0.6123362558532286, 0.609589837265144, 0.6040247368052883, 0.6008146435739421, 0.5991202052153121, 0.5807855443241386, 0.5744397459149677, 0.5673278934646802, 0.5649061167219124, 0.5634775270359396, 0.5573989908263618, 0.5564587715799396, 0.5563195854281223, 0.5556901519698544, 0.5502808912961562, 0.5388312002892587, 0.5313434424920798, 0.5184592947302574, 0.5095723028484606, 0.5095069662123206, 0.50

[0.8513743409400902, 0.8455869076798046, 0.806058197369321, 0.7980324935343639, 0.7979543663193671, 0.7958928946946165, 0.7850078695601651, 0.7705619480960536, 0.7679553027390227, 0.7551715720117099, 0.7475551855895483, 0.7446417238138896, 0.7396016287591509, 0.7346376564807056, 0.7269277330319351, 0.722136789805531, 0.7197357252397678, 0.7172182685042021, 0.715729682978558, 0.7077122021318223, 0.6936781659064084, 0.6907260176603216, 0.6904958572007485, 0.6886642790307166, 0.6864377488887577, 0.684167104555068, 0.6813496262775841, 0.6716163960366506, 0.661143365172552, 0.6039514786395209, 0.5825629347520949, 0.5495100898717682, 0.5251043577434128, 0.5048657935984759, 0.4459502437706424, 0.3955808580875259, 0.38767022392253864, 0.38070527029751344, 0.3800145858626147, 0.3626161922799887, 0.2966359084689784, nan, 0.812455476619323, 0.7976761237456332, 0.794423221086844, 0.791236266454664, 0.7840996211944581, 0.7679726446940369, 0.7679593654054057, 0.7650680812522219, 0.7630704395899018, 

 56%|█████▌    | 56/100 [00:00<00:00, 277.44it/s][nan, nan, nan, nan, nan, 0.7985589056544181, 0.7726316367400105, 0.744533377611211, 0.72302746861748, 0.6477761085275368, nan, nan, nan, 0.7297616298162974, 0.7090482288316916, 0.6945061713846165, 0.6766844928254866, 0.668637835123593, 0.650675122789695, 0.6201119694124132, 0.6137746390060869, 0.6086873785734834, 0.6052072209173189, 0.601668518842477, 0.599287922378471, 0.5665711175492593, 0.5545331137683617, 0.5337134147601398, 0.5252781659642743, 0.5218456266033671, 0.5181198214836739, 0.5039254667133051, 0.49567852901153747, 0.4931646523977967, 0.49080019711973, 0.4828565002198969, 0.44276717782113695, 0.36785637142230154, nan, nan, nan, nan, nan, 0.628884852078192, nan, nan, nan, 0.7321512821683722, 0.7105646872339945, 0.618618756217956, nan, nan, 0.750877370876875, 0.7042338060261989, 0.6968365849314817, 0.686762908462822, 0.6440921200523523, 0.6134735032328494, 0.6102222972561698, 0.6077601365769492, 0.600697184518219, 0.600429462

 82%|████████▏ | 82/100 [00:00<00:00, 270.60it/s][0.9235160972478919, 0.9229638547343861, 0.8993076040038348, 0.8841197444285885, 0.8746819885861257, 0.8690558795597485, 0.8529905282703294, 0.8426004316426049, 0.8325344914963256, 0.8317082269803401, 0.8221656766043628, 0.812624972702172, 0.8113285046914961, 0.7917684665799766, 0.7894535251150449, 0.7619740805428521, 0.7126092868780723, 0.6868068274079606, 0.6539775465864611, 0.649682727059224, 0.649617953350142, 0.6486763588205248, 0.6441599507104916, 0.624206626957044, 0.6209889906031465, 0.606369234712098, 0.6036794539692878, 0.6016914652499821, 0.5981484072135415, 0.5800075087112103, 0.5717197156102716, 0.5689199205535305, 0.5632985016532059, 0.560892865896156, 0.5602864353709834, 0.5584287097243581, 0.5320690707498201, 0.531716715358023, 0.4741523357674167]
[0.8617874279945588, 0.8551768031895202, 0.8450119430907361, nan, 0.8350471079532236, 0.8144341331963834, 0.8052374314528946, 0.7969062343081335, 0.7882671485432706, 0.775163858

[0.780102738737878, nan, nan, 0.7647324774520685, 0.7438508652004568, 0.7414063232391518, 0.7404049290518189, 0.7333483745573613, 0.731863059516649, 0.7288023639945759, 0.7271860884452211, 0.7215215300414989, 0.7188631245201149, 0.7154101328527179, 0.701358884762533, 0.6992039205464123, 0.6915684127544459, 0.6867119417001883, 0.6786040109781648, 0.6766792084302405, 0.674899853020567, 0.6702270255982754, 0.6658079896241118, 0.6594415859205591, 0.6582964551480658, 0.6571815998856133, 0.6536335712954243, 0.6440809628518234, 0.6412144551907011, 0.6382748055135316, 0.6250872230580486, 0.625041060171187, 0.6152013958003426, 0.6149387419944863, 0.6105597297285756, 0.6047409569713686, 0.6002526367858485, 0.5975077257150911, 0.5970460095577672, 0.5919848611743693, 0.591377589960995, 0.5848654633002816, 0.5839965072176846, 0.5823171451831181, 0.5812273866775459, 0.5807874223348287, 0.5753036211721039, 0.5750249359796344, 0.5671290938112492, 0.5518578003631419, 0.5417871012665781, 0.5407688178023

[0.8510132915608196, 0.8200132434482394, 0.792475551879261, 0.7916639912609733, 0.7872766793433457, 0.7854342724017377, 0.7846032493972462, 0.7843733808794402, 0.7781864673670271, 0.7620479128655124, 0.758687895897011, 0.7450281847249522, 0.7374540870266075, 0.7349576554059414, 0.7349128056212201, 0.7181283046525019, 0.7160388695236056, 0.6972506150445908, 0.6893761383558201, 0.6892378391800134, 0.6861565585665029, 0.672075200566027, 0.6678868917444735, 0.6611233659908624, 0.6597733235324876, 0.6567308029931285, 0.6562115271204106, 0.6495249874104011, 0.6490029961284132, 0.6467369625679849, 0.646553933105848, 0.6341024596481989, 0.6307231061105021, 0.6291667516712705, 0.6282760136052086, 0.628212317537798, 0.6258468209805461, 0.6235172860530378, 0.6200415001943151, 0.6171722938609099, 0.6086662253848635, 0.6047500332242931, 0.6044097840889658, 0.6041930753895439, 0.6039178409363798, 0.5951640691119221, 0.5935691011278401, 0.5925603659483443, 0.5888248122659497, 0.5873479082672448, 0.58

100%|██████████| 100/100 [00:00<00:00, 269.03it/s]
ModelEvaluation(mapk=0.49792596371882086)
****************************************************************************************************
Taking split number 2 as the validation set...
Train indexes: 5901934
Val indexes: 1475484
Users: 30516
Building sessions
Building vocabulary
Training word2vec
Word2Vec(vocab=89014, size=32, alpha=0.03)
Building user embeddings
5901934it [05:22, 18274.64it/s]
Evaluating the model
  0%|          | 0/100 [00:00<?, ?it/s][nan, 0.7906005224745616, 0.7796023767379834, 0.7783328149203057, 0.7568296429744623, 0.751695319566666, 0.7486250149703336, 0.731166279498983, 0.7269107800514361, 0.6848890660936691, 0.6779619196475122, nan, 0.7598797050411761, 0.7221806836421353, 0.7083948700157264, 0.6684824833355993, 0.6339570787600072, 0.6329770936410954, 0.6268069685023367, 0.6244366911085555, 0.6197374249025941, 0.6101576220356589, 0.5974241641485478, 0.5677856605993722, 0.5342722611029471, 0.534158806017043

 25%|██▌       | 25/100 [00:00<00:00, 249.83it/s][nan, nan, nan, 0.8382990343643051, nan, 0.7996251397939532, nan, 0.7653258228507136, 0.7179641414593712, 0.7066104416461679, 0.6971507738605435, 0.6904827004310715, 0.6787983672570707, 0.6772234022510177, 0.6239719798647979, nan, nan, 0.6867136975850106, 0.6380216260392116, 0.6206885101650775, 0.615067613404656, 0.6144741423282302, 0.6126332505243886, 0.6120349628841364, 0.6012636902463006, 0.5961273885377623, 0.5948917588495648, 0.5941115699430661, 0.589366673061129, 0.5823010130422365, 0.576622707600759, 0.5469706387855243, 0.5439380076106354, 0.4775618086890167, 0.4675466947023989, 0.4650067739827896, 0.4406722541001853, 0.404972004306936, 0.39676907346115314, 0.3743926656093702, nan, 0.787743767535068, 0.7788021634245306, 0.7606470460438441, 0.7477251162968138, 0.5923281936396889, nan, 0.7583161697941874, 0.7539751214386178, 0.729165442801644, 0.7149048477540385, 0.697974922594298, 0.6860300230797581, 0.6821020326013696, 0.674216270

[0.7606551472107729, 0.758769135311463, nan, nan, 0.7718350903800859, 0.7578475081180773, nan, nan, 0.7713303098913336, 0.753015081026124, 0.752996147648348, 0.7257105438925916, 0.7151148467321825, 0.7140029964375799, 0.7072864437887334, 0.707009236302275, 0.7063633852517096, 0.704375783923203, 0.703156912231946, nan, 0.7007472067378302, 0.6881030503094399, 0.6874894093434388, 0.6837646058809206, 0.6818172529032202, 0.6723331763720943, 0.6683118750115339, 0.6607307203071133, 0.6588449872267951, 0.6508676298926055, 0.6460985897725114, 0.6346665797028119, 0.6345524263689286, 0.6125553803004095, 0.608362029778653, 0.6048550279633707, 0.6032489792360827, 0.5962954686653765, 0.591628621478706, 0.5882906857234109, 0.5738941445112172, 0.5360706638414914, 0.49575797387124204, 0.37458323066224325, 0.3287589954856646, nan, nan, nan, 0.7603465135733751, 0.7442048089350289, 0.743372959298597, 0.7384140339838177, 0.7323995581426935, 0.7277470773878659, 0.7212425491512044, 0.7201511383938826, 0.7194

[0.8005815499889338, 0.7725573999300974, nan, 0.7539672287847055, nan, nan, nan, nan, nan, 0.7736609670372393, 0.7565848636533978, 0.7315395835754869, 0.729076226080329, 0.7170808553507633, 0.7077279900336912, 0.6974585272397644, 0.6960887366720945, 0.6877641145162285, 0.675264357919767, 0.6594285018921131, 0.6581508250556423, 0.6485616223378253, 0.6371667751774702, 0.6300593571699592, 0.6183419853042709, 0.60423714573619, 0.5980054827876585, 0.596718525860006, 0.5868345939331109, 0.5851481270793581, 0.5474957590245138, 0.5469867344283215, 0.537303639179341, 0.509244753016718, nan, 0.7969918168805727, 0.7695048984227154, nan, nan, 0.7600517872847073, 0.7499098964521397, 0.7115543714437794, nan, 0.7389294114691194, nan, nan, nan, 0.7830466003280568, 0.7711910792473967, 0.768795928139033, 0.7654150968352611, 0.7578130430360231, nan, 0.7573883538977724, 0.7491153863284192, 0.7475117939131363, 0.7413990258370459, 0.7391769182025528, 0.7373035829215634, 0.7363358548377923, 0.732410686499339

[nan, nan, nan, 0.8566702940484058, 0.8397477608966517, 0.8045781746395505, 0.8031246444936464, 0.7892555470337563, 0.7863571503595855, 0.7808257326933439, 0.7582582808395815, 0.7531435306845944, 0.7509519736076168, 0.7506666668292827, 0.7451143019575807, 0.7410062256610379, nan, nan, 0.7978915246012895, 0.775823957405158, 0.7704220880432769, 0.7542719877777837, 0.7470074370098737, 0.7364030183890934, 0.7301604059321612, 0.7288164153464676, 0.7241017812066775, 0.72017436457652, 0.7195775935425416, 0.7194004231555842, 0.7149925554791444, 0.7059610008957481, 0.7037240640009889, 0.6988733336206724, 0.6895789307733962, 0.686652093694321, 0.6848755659458555, 0.6838132317889543, 0.6810337057342037, 0.6739819082047033, 0.6662880687720514, 0.6613010777263401, 0.6534303481712734, 0.6528160235796964, 0.6509821621157732, 0.6485258299974733, 0.6466203261830851, 0.6441063769562798, 0.6408747743506072, 0.6304915620815764, 0.6299852644393675, 0.6151860117922717, 0.5976225649411496, 0.5946006954886484

[nan, nan, nan, 0.6681132824786996, 0.6241229681675602, nan, nan, nan, nan, 0.6289110477944775, 0.6067908225487821, 0.5719391904852668, nan, nan, 0.745798176301093, 0.6634252049191484, 0.6330949368564632, 0.6054010726618606, 0.6041364809829995, 0.5980792555520857, 0.5608311719972012, 0.5482312477696542, 0.5423894639866402, 0.5416190502314747, 0.49633737490600705, 0.49162389830431297, 0.4676275001698027, 0.44765359830888657, 0.44632426239092954, 0.4138912497712061, 0.3611901321125908, nan, nan, nan, nan, 0.6925512501904688, 0.6668261508808354, 0.6491861334438889, 0.6262741713956952, nan, 0.6314357570395855, nan, nan, nan, nan, 0.688721148236213, 0.6802216459944593, 0.6161351410423402, 0.6027597400218414, 0.5985546250233563, 0.5880160900129067, 0.5736132565707222, 0.5712228551942703, 0.5675499578413689, 0.5627922466979971, 0.5561486100249143, 0.5410030015364033, 0.5179151415853838, 0.49449606668869567, 0.4688459957856481, 0.4492001215619574, 0.4394416942664013, 0.433887342958746, 0.41778

[0.9018671724459548, 0.9016623964484268, 0.8967019558533583, 0.878948618337632, 0.8612116761302562, 0.8479155267940712, 0.8457079509464432, 0.8412885074411853, 0.8401029690567196, 0.8398574134806791, 0.8336006136440751, 0.8295703797709253, 0.8292766213087449, 0.8205254664372117, 0.8172470526192932, 0.8139283862664329, 0.8125282639853185, 0.8001836439434106, 0.7988026992031798, 0.7955582646000809, 0.7943652152035204, 0.7939805391296412, 0.7889389597930236, 0.7882942314823308, 0.78636093476078, 0.7851923364916887, 0.7843108307689153, 0.7832810377201758, 0.7811647624237826, 0.776815355356388, 0.7763681097449066, 0.7736650575241486, 0.7682741066154081, 0.7649657809637409, 0.7642981435532251, 0.7639516706464009, 0.7636550800510997, 0.7546317046420783, 0.7515172182113732, 0.7492062603148874, 0.7457801847820831, 0.7426789120762352, 0.7402716671345346, 0.7399234403812095, 0.7395563263160001, 0.7363950715029293, 0.735532607996641, 0.735081869484077, 0.7321484697502663, 0.7280780237248269, 0.726

[nan, 0.8371727158035699, 0.8061388633820433, 0.7834366575895876, 0.7818475912776175, 0.7770296677788767, 0.7702866819804849, 0.7669548869072617, 0.7633431484054622, 0.7330576271056477, 0.7307592725226685, 0.7280331294532392, 0.7268330158302333, 0.7127426321937893, 0.6943120399187709, 0.6910294609356844, 0.6907163155687306, 0.6883467687268343, 0.6835515508139748, 0.6769761639038683, 0.6752301940876873, 0.6742969940074748, 0.6667022533435629, 0.6545778372838319, 0.6488328058535741, 0.6448574625430167, 0.643621223072046, 0.6427075581087218, 0.6421607153033264, 0.6413555649869045, 0.6403448755626155, 0.6330258144222983, 0.6260874942369214, 0.6205240833540521, 0.6149235607692911, 0.61347456075969, 0.6076383615412035, 0.6043618110670126, 0.5979905493821839, 0.5887734531133507, 0.5808024432870742, 0.5746918231314497, 0.5703191410465239, 0.5680847341463964, 0.5600816681815848, 0.5585935053224773, 0.5488632912299573, 0.5466754432037964, 0.5414011279082095, 0.5371244179067824, 0.534969279695387

 26%|██▌       | 26/100 [00:00<00:00, 259.76it/s][0.7879038744189923, 0.7787855559456058, 0.7776919916055879, 0.7466150995375841, 0.7255593115863783, 0.7238851712303478, 0.7087609245579858, 0.7049514397479062, 0.6957871676967424, 0.6948804843068298, 0.6805302621635568, 0.6761439642251657, 0.6639869181655771, 0.6590964467327498, 0.658676468453474, 0.6392079330777785, 0.6364966367144481, 0.6354851988915482, 0.6348806846061076, 0.6330583354601725, 0.62816799287582, 0.6262371848836871, 0.6244357241975181, 0.6142816936135722, 0.6101691708214632, 0.5973327511573729, 0.5958800026896691, 0.5902342694121799, 0.5862032287107204, 0.569680689996212, 0.5302326494564609, 0.5228929009431322, 0.4801534757116609, 0.46850838343273604, 0.44791100820389984, 0.3934194794222232, nan, 0.7735029534275396, 0.7623964296084927, 0.7590860760051913, 0.7573138020712048, 0.7486269318154548, 0.7479624694037696, 0.7470683405934478, 0.7358825440195094, 0.7289105956004213, 0.7284529820528155, 0.7246647044103857, 0.71250

[nan, nan, nan, nan, nan, nan, 0.7595026954608421, nan, 0.7762077824073548, 0.7693139935642966, 0.7593546968317384, 0.7498131228935064, 0.7457375975334172, 0.7440228948134365, 0.7381627276161193, 0.727343017319367, 0.720981937669874, 0.7075994334671927, 0.7020828842557156, 0.7001306201930666, 0.6995565742587215, 0.6994685275592448, 0.6979540634940153, 0.6856943260784825, 0.6832379796960346, 0.6788603817671103, 0.6731630120634549, 0.6671833655144408, 0.6627139346111546, 0.6609596439718337, 0.6496886031804172, 0.6481277662095478, 0.6406512043190576, 0.640361294392373, 0.6301877563711061, 0.6296212233641475, 0.6252479954043759, 0.6170047988609763, 0.6111557839831054, 0.6057781390834002, 0.5975230988915179, 0.5956210979305318, 0.5904796102295377, 0.5707409665529625, 0.56489509071878, 0.5523704827843966, 0.5432507224255979, 0.5301087344633846, 0.5068982270326906, 0.4969609538342606, 0.38339512862297087, nan, 0.7804401573434883, 0.7509356839042941, 0.7374081799452846, 0.7370666943417566, 0.7

[nan, 0.8811386027939149, 0.8655267550361744, 0.8641252238747367, 0.8556504631264512, nan, 0.8377165148490784, 0.8307645040383465, 0.8222462243578514, 0.8148220168747995, 0.7877454629118921, 0.7644703206889744, 0.7474523448966821, 0.7470891808573764, 0.7423139601684846, 0.7329306379701767, 0.7297812049891828, 0.6908465957106413, 0.6597658399931083, 0.6568752665816806, 0.5397414677471609]
[nan, 0.6585668900816023, nan, nan, nan, 0.7051248323551533, nan, nan, nan, nan, 0.7339558143593007, 0.7144076275883471, nan, 0.7079606557863225, 0.7021668453356638, 0.6921380593085789, 0.6749072415986935, 0.673861837765203, 0.64713718027717, 0.611215196284825, 0.5792517405663136, nan, 0.6556918789944153, 0.6465684214682343, 0.6442980051748696, 0.6419906307245893, 0.6404590280307493, 0.6394721969852791, 0.626065103684206, 0.5969065204786487, 0.5814630741178947, 0.5814126365836821, 0.5774813161940827, 0.5702376874590298, 0.5569297590842595, 0.5459908128077459, 0.5346242539612931, 0.5339398375164096, 0.5

 52%|█████▏    | 52/100 [00:00<00:00, 259.73it/s][0.7536620414649252, nan, 0.8701906455537661, 0.7474941803904097, 0.7426608716354557, 0.7424973264971868, 0.7383628327101597, 0.7153968048903873, 0.7033085066266807, 0.6959491319554818, 0.6848424145921506, 0.683789371573257, 0.6691420983420271, 0.5864413292557684, 0.5799719151467889, 0.5371007840954591, 0.5156614295611324, 0.5150358408021446, 0.5128968508844474, 0.5109072652138152, 0.5077459769278295, 0.4646898376713338, 0.4512720028396855, 0.4051738033721417, 0.24087788003165042]
[0.8379342688003291, 0.8217979934444365, 0.8212082782378072, nan, 0.8206479842719945, 0.8174639863667452, 0.8057725415405811, 0.803468963199853, 0.798322272365944, 0.7842956868612321, 0.7812126317883437, 0.7806546698201347, 0.7716692873760278, 0.7639871686882544, 0.762509650736243, 0.762232899329291, 0.7608893936692767, 0.7524106099279648, 0.7498271507023827, 0.7472211655374321, 0.7453686458650828, 0.7434962207965384, 0.7425451317988049, 0.7340569551331296, 0.7

 76%|███████▌  | 76/100 [00:00<00:00, 253.21it/s][nan, nan, nan, nan, nan, nan, 0.7651097998065617, 0.7645540696096431, 0.7638910252454522, 0.7619545419937388, 0.7589126209468223, 0.7495849346118466, 0.7397082925398402, 0.7370521363629318, 0.7220949862677943, 0.7076471733204176, 0.7067506741374447, 0.7046599886223844, 0.6917706428021976, 0.6779285433038074, 0.6670962471184704, 0.6669109973208877, 0.6640605797656903, 0.6590314759767345, 0.6585734159451043, 0.646576420204136, 0.6353117049769801, 0.6315252921152968, 0.6266309768636159, 0.6207326314145177, 0.6138868383900838, 0.6053447659071683, 0.6005664389458181, 0.5860985381674727, 0.5828167299924119, 0.5814668124085197, 0.5479572640116207, 0.5456013746451253, 0.5323947633365712, 0.5309756963592948, 0.49964065660310375, 0.4950496713887802, 0.4496933440224009, nan, nan, 0.7791522725020035, 0.6760505973387274, 0.6691773341355925, 0.6618371327484319, 0.6580739176178497, 0.6572916461107073, 0.6341760320907657, 0.6298575228435815, 0.62807661

[0.8653404633530631, 0.7990318981615753, 0.7937578353479273, 0.7912448519947516, 0.7910348215495449, 0.7828147932821937, 0.7705730140937793, 0.7693856353313059, 0.7413624663603168, 0.7031632525113968, 0.6585450144591992, 0.6310751699124998]
[nan, 0.8711942471393684, 0.8585627205140547, 0.8421526188327887, 0.8372529563749324, 0.8323183815599257, 0.8320527602105372, 0.8278461467058938, 0.82496245533927, 0.8214290485187468, 0.7994778469083934, 0.7941137750537025, 0.7834189187472622, 0.7832207888307923, 0.7768501170953844, 0.7765312012945104, 0.7748968337810248, 0.7680639766924289, 0.7666748905660146, 0.7533032448377647, 0.7508171381537458, 0.7501885304576331, nan, 0.8128574987568985, 0.7910853957822503, 0.7490101463873832, 0.7423741987245088, 0.7402817609687512, 0.7396720519544705, 0.7392256570900633, 0.724519064693464, 0.7008849896172364, 0.700340750019241, 0.699087291747393, 0.6934149644149215, 0.6761097139253076, 0.6726739111405928, 0.6666819879010194, 0.6608651514179775, 0.65489820771

100%|██████████| 100/100 [00:00<00:00, 249.90it/s]
ModelEvaluation(mapk=0.5106169123204836)
****************************************************************************************************
Taking split number 4 as the validation set...
Train indexes: 5901935
Val indexes: 1475483
Users: 30530
Building sessions
Building vocabulary
Training word2vec
Word2Vec(vocab=88778, size=32, alpha=0.03)
Building user embeddings
5901935it [05:19, 18499.67it/s]
Evaluating the model
  0%|          | 0/100 [00:00<?, ?it/s][nan, nan, 0.8794749257311952, 0.8523384733314675, 0.821026383243678, 0.791634119710323, 0.7843615890639304, 0.7732559365902857, 0.7624290350725843, 0.7508431691779635, 0.74772590061717, 0.7392992222094551, 0.7360930775659792, 0.7298077762971446, 0.7284316095128958, 0.6862790535686208, 0.6841683475019483, 0.6824904264879099, 0.6802553772517886, 0.6769715234796393, 0.671581710057559, 0.6544830180857207, 0.6298385952233994, 0.6226835989341974, 0.6100345691105811, 0.5942608953329751, 0

 25%|██▌       | 25/100 [00:00<00:00, 248.53it/s][nan, 0.7131884410694357, 0.6992783963731793, nan, nan, nan, 0.7267059008519747, 0.7052029364132752, 0.6589873844063665, 0.6535547309108504, 0.6413513550919259, 0.6233005845956995, 0.6160346871866067, 0.6120101095411589, 0.6046241025108909, 0.6043095588989797, 0.595643760305019, 0.5809005484619805, 0.5767236023811924, nan, 0.5848848955630707, 0.5674193695403498, 0.5489821758777897, 0.5448515833951045, 0.5240163257942465, 0.5151793889350532, 0.49084483684809427, 0.4719959737701622, 0.4369008496400027]
[0.7945342876393617, nan, 0.8377886709822936, 0.7869805378590664, 0.7575654646854109, 0.7571183134526851, 0.7519200224375696, 0.7434024569639205, 0.7368167246083235, 0.7261534529647495, 0.7206720347701993, 0.7200734814490791, 0.7137378467422811, 0.6834465288466436, 0.6752206380565603, 0.6641124609929393, 0.6602267743840189, 0.6571311835761896, 0.649272383354012, 0.6482528890633424, 0.6397544344297498, 0.6222571927807347, 0.5644725158475669, 

[nan, nan, nan, nan, nan, nan, nan, 0.7421204763686592, nan, nan, 0.8368736843674052, 0.8289852144369647, 0.8286543814326518, 0.7966463456886423, 0.7928645218749137, 0.7860747110168291, 0.7853360665178076, 0.7599885081302991, 0.7541481103913962, 0.7458089657497518, 0.7391177620854245, 0.7350152651100816, 0.7342781508872904, 0.7096750944904947, 0.7057270931098621, 0.6998635766643434, 0.6988091295412975, 0.6966720276088302, 0.6950229384379173, 0.6865533097828345, 0.6840186622130244, 0.6786077796419592, 0.6746476729527827, 0.66993380091249, 0.6669933040946213, 0.6201173366440453, 0.6152690123820257, nan, 0.6659348221324578, 0.6484790641463419, 0.6473828651783845, 0.6152375454390562, 0.6141907697124799, 0.5727192900344206, 0.5703343553497667, 0.567632978551387, 0.5581106093390691, 0.551360216104979, 0.5311047025454613, 0.5226222505020847, 0.4911223397081761, 0.4770951384593527, 0.46929951685185556, 0.36858576051842323, nan]
[0.7306120432332104, 0.7120365636746431, 0.6963597674716523, nan, 

[nan, 0.7205100926156559, 0.7158447563276011, 0.7080930538858677, 0.7064911238250287, 0.7020395037697873, 0.7001524461709508, 0.6981431795789999, 0.6969104660090479, 0.6960465379568326, 0.6953422185708651, 0.6949431581321264, 0.6879003590894409, 0.687334432810566, 0.6868736260055921, 0.6833455624299376, 0.6813123438851199, 0.6769324110423124, 0.6754203617633592, 0.6736587879561962, 0.6723523168124094, 0.6714198260645168, 0.6651116421585124, 0.6647660907108504, 0.6599779064503044, 0.6581803839263736, 0.6553089941230902, 0.6512346151422664, 0.6508019814068217, 0.649463582223339, 0.6467535031181195, 0.6431548527557498, 0.6418763620623342, 0.6394303833676069, 0.633544399046844, 0.6318729496811278, 0.6305761110744893, 0.6304885228090847, 0.6298726501552181, 0.6278653686930451, 0.6230821931289323, 0.6212803653295413, 0.6209653767738261, 0.6174066623573103, 0.6173998835700385, 0.6159296457709599, 0.6106106771956203, 0.6097785453791617, 0.6027231946578395, 0.6023565561708141, 0.590546238888136

[nan, nan, nan, 0.8121675565155841, 0.7944358736674868, 0.7659247882478092, 0.7576266677056643, 0.7492484570135022, 0.7483670045137306, 0.7407676608891249, 0.7220953791754403, 0.720559359186495, 0.7163742004491472, 0.6996214638125405, 0.6984344578889924, 0.6971149917404179, 0.6843375781218579, 0.6554473301928195, 0.6533745188315444, 0.6442695225420324, 0.6427528350506483, 0.6391240770093746, 0.6324799931231978, 0.597230127098298, 0.5885176773671721, 0.5747486016595971, 0.566237706831646, 0.558073540232478, 0.5547198055654895, 0.5491982114886106, 0.5389612296565021, 0.5360028839150813, 0.5357975556712544, 0.5351149571063023, 0.5329584157046483, 0.5298660401976223, 0.5247000805436584, 0.5234074873244056, 0.5199622186833137, 0.5095803663440239, 0.5038175296641269, 0.450114818285353, 0.4469049571597042, 0.4021666073944868, nan, 0.861244442798762, 0.8556430351166123, 0.821771693216604, 0.8137149361937895, 0.8072194618310295, 0.7488481168885738, 0.7393771075432154, 0.7326026313870629, 0.7169

[0.8311368828078248, 0.7910228474504166, 0.7891911385776411, 0.7729156018410451, 0.7725577677811655, nan, nan, 0.8019585980097746, 0.7693521728941188, 0.7690256977749819, 0.7548445114748321, 0.7446386059588458, 0.7434746189726005, 0.7352575273814992, 0.7216746790494872, 0.7110650187161687, 0.7104466657179358, 0.7085271602156674, 0.7072947694042748, 0.6951949707920358, 0.6949229261208484, 0.6946311111405689, 0.6862091403754799, 0.6858384518907156, 0.6773142115703877, 0.6760490410221097, 0.6747505048361082, 0.6656335830973479, 0.6597058097178031, 0.6512741102154219, 0.6496642688460976, 0.6444260783014085, 0.6384941083163684, 0.6346460878810253, 0.6298337347460735, 0.6212318054987239, 0.6144321139829826, 0.6138110751657302, 0.6007825398756074, 0.5945745188884017, 0.5940997271184999, 0.5933211990680983, 0.5926682063035389, 0.5922478372699062, 0.5917125944073764, 0.5901534465173315, 0.5894715318403163, 0.5838747826937217, 0.5671209422296074, 0.5440947894987579, 0.5176430715882991, 0.4980701

100%|██████████| 100/100 [00:00<00:00, 258.46it/s]
ModelEvaluation(mapk=0.5168065192743764)


0.5061806415343915

## Word2Vec + GBDT

In [39]:
def get_user_song_similarity(wv, row):
  global user_embeddings
  user_id = row['msno']
  song_id = row['song_id']
  user_embedding = user_embeddings[user_id]
  song_embedding = wv[song_id] if song_id in wv else np.zeros(WORD2VEC_EMBEDDING_SIZE)
  return word2vec_similarity(user_embedding, song_embedding)

def extend_data_with_word2vec(word2vec, X):
  wv = word2vec.wv
  X = X.copy()
  X['w2v_sim'] = X.apply(lambda row: get_user_song_similarity(word2vec.wv, row), axis=1)
  return X

def train_word2vec_lightgbm(X_train, y_train):
  word2vec = train_word2vec(X_train, y_train)
  X_train = extend_data_with_word2vec(word2vec, X_train)
  model = train_lightgbm(X_train, y_train)
  return word2vec, model

In [40]:
evaluations = k_fold_cross_validation(
  train_df,
  k=CROSS_VALIDATION_K,
  train_model_fn=train_word2vec_lightgbm,
  preprocess_val_fn=lambda models, X: extend_data_with_word2vec(models[0], X),
  build_ranking_fn=lambda models: lambda user_id, songs_df: lightgbm_rank_by_user(models[1], user_id, songs_df))
mapks = list(map(lambda evaluation: evaluation.mapk, evaluations))
np.mean(mapks)

****************************************************************************************************
Taking split number 0 as the validation set...
Train indexes: 5901934
Val indexes: 1475484
Users: 30548
Building sessions
Building vocabulary
Training word2vec
Word2Vec(vocab=88866, size=32, alpha=0.03)
Building user embeddings
5901934it [05:18, 18507.99it/s]
Training the model
[10]	training's auc: 0.763694
[20]	training's auc: 0.774459
[30]	training's auc: 0.77925
[40]	training's auc: 0.783557
[50]	training's auc: 0.786928
Evaluating the model
100%|██████████| 100/100 [03:48<00:00,  2.28s/it]
ModelEvaluation(mapk=0.6797367929579239)
****************************************************************************************************
Taking split number 1 as the validation set...
Train indexes: 5901934
Val indexes: 1475484
Users: 30543
Building sessions
Building vocabulary
Training word2vec
Word2Vec(vocab=88921, size=32, alpha=0.03)
Building user embeddings
5901934it [05:19, 18499.12it/s

0.6846374140211641

In [41]:
X_train = train_df.drop(['target'], axis=1)
y_train = train_df['target'].values
word2vec, model = train_word2vec_lightgbm(X_train, y_train)
X_train = extend_data_with_word2vec(word2vec, X_train)
evaluation = evaluate_model_ranking_fn(
  lambda user_id, songs_df: lightgbm_rank_by_user(model, user_id, songs_df),
  X_train,
  y_train)

Users: 30755
Building sessions
Building vocabulary
Training word2vec
Word2Vec(vocab=101648, size=32, alpha=0.03)
Building user embeddings
7377418it [06:41, 18375.53it/s]
Training the model
[10]	training's auc: 0.761852
[20]	training's auc: 0.771767
[30]	training's auc: 0.776691
[40]	training's auc: 0.780748
[50]	training's auc: 0.784544
Evaluating the model
100%|██████████| 100/100 [03:47<00:00,  2.27s/it]
ModelEvaluation(mapk=0.8461059523809525)


In [42]:
np.random.seed(SEED)
explain_indexes = np.random.choice(X_train.index, EXPLAIN_FROM_SAMPLES, replace=False)
X_explain = X_train.iloc[explain_indexes]

In [43]:
explainer = shap.TreeExplainer(model, feature_perturbation='tree_path_dependent')
shap_values = explainer.shap_values(X_explain)

LightGBM binary classifier with TreeExplainer shap values output has changed to a list of ndarray


In [44]:
index = 0
shap.initjs()
shap.force_plot(
  explainer.expected_value[0],
  shap_values[0][index,:],
  X_explain.iloc[index,:])

In [45]:
X_test = test_df.drop(['id'], axis=1)
X_test = extend_data_with_word2vec(word2vec, X_test)
test_ids = test_df['id'].values

p_test = model.predict(X_test)

submission_df = pd.DataFrame()
submission_df['id'] = test_ids
submission_df['target'] = p_test
submission_df.to_csv('lightgbm_word2vec_submission_df.csv', index=False, float_format='%.6f')

In [46]:
# TOOD(niksaz): Update it to print the time taken to run the notebook.

from datetime import datetime

now = datetime.now()

current_time = now.strftime("%H:%M:%S")
print("Current Time =", current_time)

Current Time = 04:10:23
