In [9]:
import glob
import os
from numpy import NaN
import pandas as pd
from gensim.models.keyedvectors import KeyedVectors
from sklearn.preprocessing import LabelBinarizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.multioutput import MultiOutputClassifier

In [10]:
files = glob.glob(os.path.join("../../data/result", "*.csv"))
df_list = []
for file in files:
    tmp_df = pd.read_csv(file)
    tmp_df['filename'] = os.path.basename(file)
    df_list.append(tmp_df)
df = pd.concat(df_list, ignore_index=True)

In [11]:
# predの要素で, propertyが含まれる値を返す
df_prop = df.query('pred.str.contains("property")', engine='python')
# 'wikiPage'を含まないpropertyを返す
df_prop = df_prop[~df_prop['pred'].str.contains('wikiPage')]
df_prop = df_prop[~df_prop['pred'].str.contains('画像')]
top_10_df_prop = df_prop.groupby('pred').count().sort_values(['obj'], ascending=False).head(10)
top_10_prop_list = top_10_df_prop.index.to_list()
top_10_df = df[df['pred'].isin(top_10_prop_list)]
prop_to_label = dict(zip(top_10_prop_list, range(0,10)))

In [12]:
top_10_prop_list

['http://ja.dbpedia.org/property/隣接都道府県',
 'http://ja.dbpedia.org/property/before',
 'http://ja.dbpedia.org/property/歌など',
 'http://ja.dbpedia.org/property/所在地',
 'http://ja.dbpedia.org/property/説明',
 'http://ja.dbpedia.org/property/表記',
 'http://ja.dbpedia.org/property/years',
 'http://ja.dbpedia.org/property/after',
 'http://ja.dbpedia.org/property/シンボル名',
 'http://ja.dbpedia.org/property/title']

In [13]:
model = KeyedVectors.load('../../Models/japanese-word2vec-model-builder/word2vec.gensim.model')

In [14]:
def vectorize(model, word):
    try:
        output = model.wv[word]
        return output
    except:
        return "NaN"

In [15]:
def preprocessing(obj):
    if type(obj) != str:
        output = obj
    elif len(obj.split("：")) != 1:
        output = obj.split("：")[-1]
    else:
        output = obj.split("/")[-1]
    return output

In [16]:
# vectorization
cp_top_10_df = top_10_df.copy()
cp_top_10_df['key_vec'] = cp_top_10_df['key'].map(lambda x:vectorize(model,x))
cp_top_10_df['label'] = cp_top_10_df['pred'].map(prop_to_label)
obj_list = cp_top_10_df['obj'].map(lambda x:preprocessing(x))
cp_top_10_df['obj_vec'] = obj_list.map(lambda x:vectorize(model,x))

In [17]:
top_10_df_vec = pd.merge(top_10_df,cp_top_10_df, left_index=True, right_on=top_10_df.index)
top_10_df_vec = top_10_df_vec.reset_index(drop=True)

In [18]:
top_10_df_vec = top_10_df_vec.drop(["key_0","filename_x","key_y","pred_y","obj_y","filename_y"], axis=1)

In [19]:
#top_10_df_vec_exclusion_null = top_10_df_vec[top_10_df_vec['obj_vec'] != "Null"].reset_index(drop=True)
top_10_df_vec_exclusion_nan = top_10_df_vec[top_10_df_vec['obj_vec'] != "NaN"].reset_index(drop=True)
#top_10_df_vec_exclusion_null

  result = libops.scalar_compare(x.ravel(), y, op)


In [20]:
top_10_df_vec_exclusion_nan

Unnamed: 0,key_x,pred_x,obj_x,key_vec,label,obj_vec
0,沖縄県,http://ja.dbpedia.org/property/after,http://ja.dbpedia.org/resource/琉球列島米国軍政府,"[0.03427477, 0.17860307, -0.041606415, 0.09745...",7,"[-0.13429342, -0.09761905, 0.068803005, 0.0482..."
1,沖縄県,http://ja.dbpedia.org/property/before,http://ja.dbpedia.org/resource/琉球列島米国民政府,"[0.03427477, 0.17860307, -0.041606415, 0.09745...",1,"[-0.08374447, 0.07326782, -0.10170767, 0.10177..."
2,沖縄県,http://ja.dbpedia.org/property/before,http://ja.dbpedia.org/resource/琉球政府,"[0.03427477, 0.17860307, -0.041606415, 0.09745...",1,"[-0.050090242, 0.18710567, -0.044265315, 0.181..."
3,沖縄県,http://ja.dbpedia.org/property/before,http://ja.dbpedia.org/resource/琉球藩,"[0.03427477, 0.17860307, -0.041606415, 0.09745...",1,"[-0.20956942, -0.06843233, -0.041769836, -0.07..."
4,沖縄県,http://ja.dbpedia.org/property/years,1879,"[0.03427477, 0.17860307, -0.041606415, 0.09745...",6,"[-0.10014486, 0.016131686, -0.15636334, 0.0958..."
...,...,...,...,...,...,...
586,高知県,http://ja.dbpedia.org/property/表記,次,"[-0.021919195, 0.10427152, 0.04427553, 0.12341...",5,"[-0.12088986, 0.10902696, -0.06065345, 0.07214..."
587,高知県,http://ja.dbpedia.org/property/歌など,県の魚：カツオ,"[-0.021919195, 0.10427152, 0.04427553, 0.12341...",2,"[-0.1669918, 0.011190996, 0.08135691, 0.178282..."
588,高知県,http://ja.dbpedia.org/property/隣接都道府県,http://ja.dbpedia.org/resource/大分県,"[-0.021919195, 0.10427152, 0.04427553, 0.12341...",0,"[0.01619406, 0.14829713, 0.05118972, 0.1368117..."
589,高知県,http://ja.dbpedia.org/property/隣接都道府県,http://ja.dbpedia.org/resource/徳島県,"[-0.021919195, 0.10427152, 0.04427553, 0.12341...",0,"[0.013457059, 0.049274415, 0.034320407, 0.1147..."


In [21]:
input_data = list(top_10_df_vec_exclusion_nan['obj_vec'].values)
label_data = top_10_df_vec_exclusion_nan['label'].values
label_dence = LabelBinarizer().fit_transform(label_data)
X_train, X_test, y_train, y_test = train_test_split(input_data, label_dence)

In [22]:
forest = RandomForestClassifier(n_estimators = 10,random_state=1)
#multi_target_forest = MultiOutputClassifier(forest, n_jobs=-1)

In [28]:
forest.fit(input_data, label_dence)

RandomForestClassifier(n_estimators=10, random_state=1)

In [29]:
y_pred = forest.predict(input_data)
y_pred_proba = forest.predict_proba(input_data)

In [31]:
from sklearn.metrics import accuracy_score
accuracy_score(label_dence, y_pred)

0.9627749576988156

In [36]:
y_pred[4]

array([0, 0, 0, 0, 0, 0, 1, 0, 0, 0])

In [19]:
y_pred_proba[0]

array([[1.        , 0.        ],
       [0.1       , 0.9       ],
       [1.        , 0.        ],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.16785714, 0.83214286],
       [1.        , 0.        ],
       [0.2       , 0.8       ],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [0.8       , 0.2       ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.23206349, 0.76793651],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [1.        , 0.        ],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [1.        , 0.        ],
       [1.        , 0.        ],
       [1.        , 0.        ],
       [0.        , 1.        ],
       [0.        , 1.        ],
       [0.9       , 0.1       ],
       [0.1       , 0.9       ],
       [0.4       , 0.6       ],
       [0.1       , 0.9       ],
       [1.