In [1]:
# Mount Google Drive at /content/drive; the notebook later changes directory
# into a folder beneath this mount point.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
tqdm.pandas()
import os
import gc
import random
from glob import glob
from sklearn.model_selection import GroupKFold, KFold, StratifiedKFold
import warnings
import seaborn as sns
import pickle
import json
import re
import time
import sys
from requests import get
import multiprocessing
import joblib

from pathlib import Path

from gensim.models import word2vec, KeyedVectors

class CFG:
    """Notebook-wide configuration constants."""

    seed = 46                     # seed for random / numpy / PYTHONHASHSEED
    target = "point_of_interest"  # name of the target column in train.csv
    n_neighbors = 10
    n_splits = 3

    # Experiment ID (Colab only): the `name` field of the first Colab session
    # (presumably the notebook filename), with the extension and anything after
    # the first '-' stripped. This is best-effort: the endpoint is Colab-internal,
    # so a timeout prevents an indefinite hang and any network / JSON-shape
    # failure falls back to the same "" used outside Colab instead of aborting
    # the whole setup cell.
    expID = ""
    if "google.colab" in sys.modules:
        try:
            expID = (get("http://172.28.0.2:9000/api/sessions", timeout=10)
                     .json()[0]["name"].split(".")[0].split("-")[0])
        except (OSError, ValueError, KeyError, IndexError, TypeError):
            # requests' exceptions derive from OSError; bad JSON raises ValueError.
            expID = ""

random.seed(CFG.seed)
os.environ["PYTHONHASHSEED"] = str(CFG.seed)
np.random.seed(CFG.seed)

plt.rcParams["font.size"] = 13
warnings.filterwarnings('ignore')

%cd /content/drive/MyDrive/Kaggle/Foursquare/Notebook

/content/drive/MyDrive/Kaggle/Foursquare/Notebook


In [None]:
!pip install texthero
import texthero as hero

In [None]:
import nltk

# Fetch the NLTK stop-word corpus that `custom_stopwords` is built from.
nltk.download('stopwords')
# List the available stop-word languages through the corpus reader rather than
# probing a hard-coded ~/nltk_data path: nltk.download() may install into a
# different writable directory on nltk.data.path.
nltk.corpus.stopwords.fileids()

In [None]:
# Concatenate the NLTK stop-word lists of the languages below into one flat
# list, in this order (words shared by several languages appear repeatedly).
custom_stopwords = []
for lang in ('english', 'azerbaijani', 'danish', 'arabic', 'russian',
             'finnish', 'portuguese', 'greek', 'swedish', 'french',
             'dutch', 'spanish', 'nepali', 'indonesian', 'german',
             'hungarian', 'turkish', 'italian', 'norwegian', 'romanian'):
    custom_stopwords.extend(nltk.corpus.stopwords.words(lang))

In [19]:
# Load both splits; test rows get the placeholder target "TEST" so the two
# frames share the same columns.
train, test = (pd.read_csv(path) for path in ("../Input/train.csv", "../Input/test.csv"))
test[CFG.target] = "TEST"

train.head(1)

Unnamed: 0,id,name,latitude,longitude,address,city,state,zip,country,url,phone,categories,point_of_interest
0,E_000001272c6c5d,Café Stad Oudenaarde,50.859975,3.634196,Abdijstraat,Nederename,Oost-Vlaanderen,9700,BE,,,Bars,P_677e840bb6fc7e


In [31]:
# Replace missing `name` / `categories` with the literal string "NaN" so every
# row can be tokenised with str.split() downstream.
train = train.fillna({'name': "NaN", 'categories': "NaN"})
test = test.fillna({'name': "NaN", 'categories': "NaN"})

In [32]:
# Vocabulary size of `name`: number of distinct whitespace-separated tokens
# across the training set (tqdm shows progress over the rows).
S = set().union(*tqdm(train['name'].map(lambda text: set(text.split()))))
len(S)

  0%|          | 0/1138812 [00:00<?, ?it/s]

535344

In [33]:
# Vocabulary size of `categories`: number of distinct whitespace-separated
# tokens across the training set (tqdm shows progress over the rows).
S = set().union(*tqdm(train['categories'].map(lambda text: set(text.split()))))
len(S)

  0%|          | 0/1138812 [00:00<?, ?it/s]

1232

In [36]:
# Word-vector dimensionality per text column.
# Chosen by hand relative to each column's raw vocabulary size
# (~535k distinct name tokens vs ~1.2k category tokens, counted above).
model_size = dict(name=768, categories=24)

# Number of Word2Vec training iterations (epochs).
n_iter = 100

In [37]:
def _w2v_dim_epoch_kwargs(dim: int, epochs: int) -> dict:
    """Return the vector-size / epoch keyword arguments for the installed gensim.

    gensim 4.0 renamed ``size`` -> ``vector_size`` and ``iter`` -> ``epochs``;
    passing the old names to gensim >= 4 raises ``TypeError``. The right names
    are picked from ``Word2Vec.__init__``'s signature so the cell runs on both.
    """
    import inspect

    params = inspect.signature(word2vec.Word2Vec.__init__).parameters
    if "vector_size" in params:
        return {"vector_size": dim, "epochs": epochs}
    return {"size": dim, "iter": epochs}


def _mean_token_vector(tokens: list, wv, dim: int) -> np.ndarray:
    """Average the word vectors of ``tokens`` into one sentence vector.

    For an empty token list (e.g. a whitespace-only string) a float32 zero
    vector of length ``dim`` is returned; ``np.mean([])`` would instead give a
    scalar NaN and make the later ``np.vstack`` fail on mismatched shapes.
    """
    if not tokens:
        return np.zeros(dim, dtype=np.float32)
    return np.mean([wv[token] for token in tokens], axis=0)


# One DataFrame of mean word vectors per text column, indexed by `id`.
w2v_dfs = []
for df_name in ('name', 'categories'):
    dim = model_size[df_name]
    df = train[['id', df_name]].copy()
    df[df_name] = df[df_name].map(lambda x: x.split())

    # Train Word2Vec on the whitespace-tokenised column. `seed` ties gensim's own
    # RNG to CFG.seed; per gensim's docs a bit-for-bit reproducible run also needs
    # workers=1 and PYTHONHASHSEED fixed before the interpreter starts.
    w2v_model = word2vec.Word2Vec(df[df_name].values.tolist(),
                                  min_count=1,
                                  window=1,
                                  seed=CFG.seed,
                                  **_w2v_dim_epoch_kwargs(dim, n_iter))

    # Sentence vector = mean of the sentence's word vectors.
    wv = w2v_model.wv
    sentence_vectors = df[df_name].progress_apply(
        lambda tokens: _mean_token_vector(tokens, wv, dim))
    sentence_vector_df = pd.DataFrame(np.vstack(sentence_vectors.tolist()),
                                      columns=[f"{df_name}_w2v_{i}" for i in range(dim)])
    sentence_vector_df.index = df["id"]
    w2v_dfs.append(sentence_vector_df)

  0%|          | 0/1138812 [00:00<?, ?it/s]

  0%|          | 0/1138812 [00:00<?, ?it/s]

In [38]:
# Invariant: one row per training record and one column per `name` embedding dimension.
assert w2v_dfs[0].shape == (len(train), model_size["name"])
w2v_dfs[0]

Unnamed: 0_level_0,name_w2v_0,name_w2v_1,name_w2v_2,name_w2v_3,name_w2v_4,name_w2v_5,name_w2v_6,name_w2v_7,name_w2v_8,name_w2v_9,...,name_w2v_758,name_w2v_759,name_w2v_760,name_w2v_761,name_w2v_762,name_w2v_763,name_w2v_764,name_w2v_765,name_w2v_766,name_w2v_767
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
E_000001272c6c5d,-0.128070,-0.290525,0.377483,-0.298867,-0.541073,0.139544,0.180281,0.117893,0.211447,0.142901,...,0.411194,0.091078,-0.529303,0.089574,0.168886,0.484387,-0.132952,0.307027,-0.045498,-0.092591
E_000002eae2a589,-0.111135,-0.129628,-0.281206,-0.299045,-0.068080,0.304450,-0.140774,0.155587,0.088788,0.332196,...,0.014041,-0.020166,-0.300935,-0.083291,-0.091328,0.049496,-0.257144,-0.369133,0.064085,-0.024180
E_000007f24ebc95,0.000020,-0.000512,-0.000275,0.000277,0.000059,-0.000119,-0.000646,-0.000367,0.000524,0.000004,...,0.000622,0.000194,0.000163,0.000283,0.000415,-0.000469,-0.000553,-0.000313,-0.000419,0.000210
E_000008a8ba4f48,0.161576,-0.111229,-0.579936,-0.190412,0.754001,-0.319698,-0.658304,-0.166837,0.325079,1.353685,...,0.315053,0.295706,0.052273,-0.245261,-0.361820,-0.539114,0.063634,-0.115709,-0.411379,-0.588273
E_00001d92066153,0.010899,-0.309584,-0.026529,-0.657904,-0.878519,0.685841,-0.271493,-0.227291,0.599435,0.488928,...,-0.016102,0.187755,-0.769442,-0.304488,-0.079340,0.356391,0.298493,-0.190693,0.119596,-0.089197
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
E_ffffb80854f713,0.028108,0.017983,-0.038874,-0.057113,0.055869,0.024891,-0.038349,-0.007173,0.008841,0.037868,...,0.072911,0.048870,-0.048269,-0.061853,-0.008384,0.040424,-0.020579,-0.021909,-0.018071,-0.065743
E_ffffbf9a83e0ba,-0.337830,-0.021620,0.057316,-0.478716,-0.226242,-0.048042,0.293066,0.044944,0.005972,0.220146,...,0.220861,-0.287470,-0.048429,0.056402,0.318056,0.538215,0.121695,0.351379,0.541854,0.483025
E_ffffc572b4d35b,-0.021009,-0.251012,-0.079286,-0.378414,0.616468,-0.058828,-0.409393,-0.332349,0.019730,0.225122,...,0.661314,0.042429,-0.438725,0.055647,-0.088321,0.074090,-0.362868,-0.427179,-0.267799,-0.185070
E_ffffca745329ed,0.100873,0.176165,-0.239191,-0.026140,0.182148,0.262656,-0.064830,0.064936,0.228181,-0.084772,...,0.138297,0.080083,-0.072419,-0.330473,-0.054541,0.187272,-0.155569,0.078622,-0.124151,-0.334205


In [39]:
# Invariant: one row per training record and one column per `categories` embedding dimension.
assert w2v_dfs[1].shape == (len(train), model_size["categories"])
w2v_dfs[1]

Unnamed: 0_level_0,categories_w2v_0,categories_w2v_1,categories_w2v_2,categories_w2v_3,categories_w2v_4,categories_w2v_5,categories_w2v_6,categories_w2v_7,categories_w2v_8,categories_w2v_9,...,categories_w2v_14,categories_w2v_15,categories_w2v_16,categories_w2v_17,categories_w2v_18,categories_w2v_19,categories_w2v_20,categories_w2v_21,categories_w2v_22,categories_w2v_23
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
E_000001272c6c5d,-2.787149,-2.014940,2.261574,0.980031,-3.634882,-4.760931,3.709704,4.783492,3.558110,-0.103200,...,-2.865034,-0.167246,1.188223,-2.044157,-1.468691,0.783087,-1.835027,2.561566,-6.316105,-0.100111
E_000002eae2a589,-2.269612,-1.646501,0.746487,-2.090281,-2.077449,0.632139,-0.555835,1.950902,0.373822,0.823460,...,-0.105707,1.742347,-0.741433,-0.965941,0.513004,1.246182,-0.946132,2.053953,-1.070638,-1.393208
E_000007f24ebc95,-1.633180,-1.691462,-2.221207,2.834297,0.189379,-0.309728,-1.858372,-2.293568,3.508495,0.816830,...,-0.984072,0.439374,4.001382,1.675677,-3.523839,5.382437,1.738346,-1.505293,5.066903,-2.819421
E_000008a8ba4f48,-2.077633,-0.008327,1.809032,1.420867,-2.539398,3.490407,0.145626,3.504977,1.559026,1.632184,...,0.751571,-0.153867,-1.062664,3.714954,1.890597,0.430341,-4.050965,-1.129406,-0.970262,1.917486
E_00001d92066153,-2.186944,-1.074642,1.202924,-1.354595,-2.415472,-0.275457,-0.523377,1.757013,0.377464,1.111827,...,-0.241873,1.331262,-2.044493,-1.337900,0.302813,1.077792,-0.771823,2.195816,-0.588790,-1.273589
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
E_ffffb80854f713,0.014425,-0.017644,-0.010850,0.004244,0.009522,0.014425,-0.000465,-0.000319,-0.013402,0.006294,...,0.012030,-0.007480,-0.011331,-0.014107,0.020622,-0.015852,-0.005683,-0.014461,-0.012732,0.007215
E_ffffbf9a83e0ba,-1.222658,-2.475937,-5.221395,1.586696,-2.491172,-2.766210,0.052481,4.792562,1.228772,1.099474,...,-1.533191,6.017534,3.753853,-0.850221,-2.447177,1.492226,0.736222,-1.280108,-1.138460,-2.350311
E_ffffc572b4d35b,-4.714374,2.969258,3.298482,-1.390142,-1.597502,0.351776,-2.555738,1.405403,0.857619,0.783624,...,0.678685,2.972791,1.595950,1.480514,3.194741,-1.049488,-0.202349,1.530228,1.554664,1.964149
E_ffffca745329ed,1.868375,0.028487,-2.543487,-3.811417,0.900811,0.840862,-0.046999,1.043458,0.117141,-1.757565,...,-0.648088,5.172011,-0.158588,-2.919411,-1.192981,-1.041145,-1.129058,3.525371,-4.407477,2.238769
