In [1]:
import sys
sys.getdefaultencoding()

'utf-8'

In [2]:
# -*- coding: utf-8 -*-
# python3
#
import argparse
import gzip
import math
import numpy
import re
import sys
import numpy as np
from copy import deepcopy
import codecs

In [3]:
isNumber = re.compile(r'\d+.*')
def norm_word(word):
    if isNumber.search(word.lower()):
        return '---num---'
    elif re.sub(r'\W+', '', word) == '':
        return '---punc---'
    else:
        return word.lower()

In [4]:
"""Read all the word vectors and normalize them"""
def read_word_vecs(filename):
    wordVectors = {}
    # ファイル読み込み
    if filename.endswith('.gz'): 
        fileObject = gzip.open(filename, 'r')
    else: 
        fileObject = codecs.open(filename, "r", "utf-8", 'ignore')
        
    for line in fileObject:
        # line = line.strip().lower()
        line = line.strip()
        word = line.split()[0]
        wordVectors[word] = numpy.zeros(len(line.split())-1, dtype=float)
        for index, vecVal in enumerate(line.split()[1:]):
            wordVectors[word][index] = float(vecVal)
        """normalize weight vector"""
        wordVectors[word] /= math.sqrt((wordVectors[word]**2).sum() + 1e-6)

    sys.stderr.write("Vectors read from: "+filename+" \n")
    return wordVectors

In [5]:
"""Write word vectors to file"""
def print_word_vecs(wordVectors, outFileName):
    sys.stderr.write('\nWriting down the vectors in '+outFileName+'\n')
    outFile = open(outFileName, 'w')  
    for word, values in wordVectors.items():
        outFile.write(word+' ')
        for val in wordVectors[word]:
            outFile.write('%.4f' %(val)+' ')
        outFile.write('\n')      
    outFile.close()

In [6]:
"""Read the PPDB.etc word relations as a dictionary"""
def read_lexicon(filename):
    lexicon = {}
    fileObject = open(filename, 'r')
    for line in fileObject:
        words = line.lower().strip().split()
        lexicon[norm_word(words[0])] = [norm_word(word) for word in words[1:]]
    return lexicon

In [7]:
"""Retrofit word vectors to a lexicon"""
def retrofit(wordVecs, lexicon, numIters):
    # Input word vecs
    newWordVecs = deepcopy(wordVecs)
    # Input word vecsの単語リスト
    wvVocab = set(newWordVecs.keys())
    # wvVocabとlexiconの共通単語
    loopVocab = wvVocab.intersection(set(lexicon.keys()))

    for _ in range(numIters):
        # loop through every node also in ontology (else just use data estimate)
        for word in loopVocab:
            # lexicon wordの近傍単語とwvVocabの共通単語とその個数
            wordNeighbours = set(lexicon[word]).intersection(wvVocab)
            numNeighbours = len(wordNeighbours)
            # no neighbours -> pass - use data estimate
            if numNeighbours == 0:
                continue
            """分散表現の更新手続き"""
            # the weight of the data estimate if the number of neighbours
            newVec = numNeighbours * wordVecs[word]
            # loop over neighbours and add to new vector (currently with weight 1)
            for ppWord in wordNeighbours:
                newVec += newWordVecs[ppWord]
            newWordVecs[word] = newVec/(2*numNeighbours)
    return newWordVecs

In [8]:
def similarity(v1, v2):
    n1 = np.linalg.norm(v1)
    n2 = np.linalg.norm(v2)
    return np.dot(v1, v2) / n1 / n2

### 変数設定

In [8]:
# Input word vecs -> original
input_arg = './word2vec/vectors.model'
# Lexicon file name
lexicon_arg = './lexicons/wordnet-jpn.txt'
# Num iterations
numiter_arg = 10
# Output word vecs -> retrofitting
output_arg = './out_vec.txt'

In [9]:
numIter = int(numiter_arg)

In [10]:
outFileName = output_arg

In [11]:
lexicon = read_lexicon(lexicon_arg)

In [12]:
wordVecs = read_word_vecs(input_arg)

Vectors read from: ./word2vec/vectors.model 


In [13]:
new_vec = retrofit(wordVecs, lexicon, numIter)

### retrofittingした分散表現を保存する

In [14]:
"""Enrich the word vectors using ppdb and print the enriched vectors"""
print_word_vecs(new_vec, outFileName)


Writing down the vectors in ./out_vec.txt


### 評価

In [116]:
word = 'トラ'
# path = "./fastText/model.vec"
negative = False # Falseなら似た単語を候補で上げる
threshold = 0.6 # -1なら閾値固定

In [320]:
vecs = wordVecs

In [322]:
vecs = new_vec

In [323]:
# 特定単語の類似度算出
v1 = 'アテネ'
v2 = 'ギリシャ'

if v1 not in lexicon:
    print("v1 not found error in dict")
if v2 not in lexicon:
    print("v2 not found error in dict")
    
try:
    print(similarity(vecs[v1], vecs[v2]))
except:
    print('error')

v2 not found error in dict
0.8562229551754349


In [348]:
vecs_wordVecs = wordVecs

In [349]:
vecs_new_vec = new_vec

In [373]:
# アナロジー算出
v1 = 'オタワ'
v2 = 'カナダ'
v3 = 'ストックホルム'
v4 = 'スウェーデン'

if v1 not in lexicon:
    print("v1 not found error in dict")
if v2 not in lexicon:
    print("v2 not found error in dict")
if v3 not in lexicon:
    print("v3 not found error in dict")
if v4 not in lexicon:
    print("v4 not found error in dict")
    
try:
    print('word2vec : {}'.format(similarity(vecs_wordVecs[v1] + vecs_wordVecs[v2], vecs_wordVecs[v3] + vecs_wordVecs[v4])))
except:
    print('error')

try:
    print('retrofitting : {}'.format(similarity(vecs_new_vec[v1] + vecs_new_vec[v2], vecs_new_vec[v3] + vecs_new_vec[v4])))
except:
    print('error')

word2vec : 0.48468072664976913
retrofitting : 0.4639886453578066


In [244]:
# wordの設定確認
if not word:
    raise Exception("word is missing")
    
# wordがモデルにない場合，
if word not in vecs:
    raise Exception("Sorry, this word is not registered in model.")
w_vec = vecs[word]

# ナレッジグラフにあるかどうかの確認
lexicon = read_lexicon(lexicon_arg)
if word not in lexicon:
#     raise Exception("not found error in dict")
    print("not found error in dict")

# 閾値の設定
border_positive = threshold if threshold > 0 else 0.8
border_negative = threshold if threshold > 0 else 0.3

# 候補数の設定
max_candidates = 20
candidates = {}

In [245]:
for w in vecs:
    try:
        if w_vec.shape != vecs[w].shape:
            raise Exception("size not match")
        s = similarity(w_vec, vecs[w])
    except Exception as ex:
        print(w + " is not valid word.")
        continue

    if negative and s <= border_negative:
        candidates[w] = s
        if len(candidates) % 5 == 0:
            border_negative -= 0.05
    elif not negative and s >= border_positive:
        candidates[w] = s
        if len(candidates) % 5 == 0:
            border_positive += 0.05

    if len(candidates) > max_candidates:
        break

754069 is not valid word.


In [246]:
# 類義語算出
sorted_candidates = sorted(candidates, key=candidates.get, reverse=not negative)
for c in sorted_candidates:
    print("{0}, {1}".format(c, candidates[c]))

トラ, 1.0000000000000002
虎, 0.7187464339674907
タイガー, 0.650029262972028
