In [176]:
import re
import datetime
import numpy as np
from gensim import models
from gensim.models.doc2vec import TaggedDocument
from sklearn.datasets import fetch_20newsgroups

In [177]:
from pkg_resources import get_distribution
import platform

# Record the interpreter version and the installed version of each library
# this notebook relies on, so the outputs below can be tied to an environment.
print("python", platform.python_version())
print("")
libs = ["numpy", "gensim", "scikit-learn"]
for lib in libs:
    print(lib, get_distribution(lib).version)

python 3.5.2

numpy 1.13.1
gensim 2.3.0
scikit-learn 0.18.2


In [178]:
# Load the data: three newsgroup categories, asking the loader to remove
# headers, footers and quotes from every post.

categories = ["alt.atheism", "rec.sport.baseball", "sci.electronics"]
strip_parts = ("headers", "footers", "quotes")
train = fetch_20newsgroups(subset="train", remove=strip_parts, categories=categories)
test = fetch_20newsgroups(subset="test", remove=strip_parts, categories=categories)
# Sanity check: documents and labels must have the same count in each split.
for split in (train, test):
    print(len(split.data), len(split.target))

1668 1668
1109 1109


In [179]:
# Split a piece of text into a list of words.
def sentence2words(sentence):
    """Tokenize ``sentence`` into a list of lowercase words.

    Steps:
      1. lowercase the text;
      2. replace every ASCII punctuation/symbol character with a space;
      3. split on any run of whitespace (spaces, newlines, tabs, ``\\r`` ...);
      4. drop tokens that contain an ASCII digit and tokens listed in the
         stop-word list.

    Parameters
    ----------
    sentence : str
        Raw text to tokenize.

    Returns
    -------
    list of str
        The remaining words, in their original order. Empty for empty input.
    """
    stop_words = ["a"]
    # Compile both patterns once per call instead of once per word.
    symbol_pattern = re.compile(r"[!-\/:-@[-`{-~]")
    digit_pattern = re.compile(r"[0-9]")

    text = symbol_pattern.sub(" ", sentence.lower())
    # split() with no argument splits on every kind of whitespace and never
    # yields empty strings, so tab- or CR-separated words are no longer fused
    # into a single token and no explicit length check is needed.
    return [
        word
        for word in text.split()
        if digit_pattern.search(word) is None and word not in stop_words
    ]

In [180]:
# Prepare the TaggedDocument objects that Doc2Vec will be trained on.

# Pair every training document with its category number once, then reuse the pairs.
labelled_docs = list(zip(train.data, train.target))

# Each document is tagged with its position in the training set.
training_docs = [
    TaggedDocument(words=sentence2words(doc), tags=[i])
    for i, (doc, _) in enumerate(labelled_docs)
]

# document number -> (raw document text, category number)
mapping = dict(enumerate(labelled_docs))
print(len(training_docs), len(mapping))

1668 1668


In [181]:
# Train Doc2Vec

EPOCH_NUM = 50
ALPHA = .025
# NOTE(review): despite its name this is used as the learning rate reached on
# the last loop iteration, not as a per-step decrement.
ALPHA_DECREASE = .002

# dm=1 => dmpv, dm!=1 => DBoW
model = models.Doc2Vec(dm=1, min_count=1, size=300, alpha=ALPHA, min_alpha=ALPHA)
model.build_vocab(training_docs)

st = datetime.datetime.now()
for epoch in range(EPOCH_NUM):
    # alpha == min_alpha, so the learning rate is constant within one train()
    # call; the decay is applied by hand between calls.
    model.train(training_docs, total_examples=model.corpus_count, epochs=model.iter)
    model.alpha -= (ALPHA - ALPHA_DECREASE) / (EPOCH_NUM - 1)
    model.min_alpha = model.alpha

    finished = epoch + 1
    if finished % 10 != 0:
        continue
    # Every 10 iterations: report progress and the time taken since the last report.
    ed = datetime.datetime.now()
    print('epoch:\t{}\talpha:\t{}\ttime:\t{}'.format(finished, model.alpha, ed-st))
    st = datetime.datetime.now()

epoch:	10	alpha:	0.020306122448979586	time:	0:00:20.367327
epoch:	20	alpha:	0.01561224489795917	time:	0:00:19.669578
epoch:	30	alpha:	0.010918367346938754	time:	0:00:20.083310
epoch:	40	alpha:	0.006224489795918341	time:	0:00:19.606074
epoch:	50	alpha:	0.0015306122448979342	time:	0:00:19.231122


In [182]:
# モデルの保存・読み込み

#model.save("doc2vec.model")
#model = models.Doc2Vec.load("doc2vec.model")

In [183]:
# Look up the trained document vector for a given tag.

doc_vector = model.docvecs[0]
print(len(doc_vector))
doc_vector

300


array([-0.39049387,  1.55187416, -0.22572853, -1.11344874, -0.39320791,
       -0.4986077 ,  0.40878046,  0.99784631,  0.08340978, -1.25524998,
        0.49508616, -0.20194784, -0.19500758, -1.19082129, -0.71139616,
        0.45409137, -1.43012714,  1.06278861, -1.40704489, -1.94125473,
        0.55191243,  1.49978495,  0.39280772, -1.67011893, -0.44794005,
        1.26614821,  0.55800962, -0.2518321 ,  1.13102543,  0.68995625,
        0.7582413 , -0.21434598, -1.22841716, -0.52394909,  2.47317338,
        1.32399714, -0.70593667,  0.19710892,  1.6270299 , -0.22111121,
        1.21336043,  0.57109952,  0.97255743,  0.10906313, -0.67421776,
        0.10775118,  2.19833016,  1.5443145 , -1.04312134, -0.69988644,
       -0.82725644, -0.73146671,  2.19396734, -0.89467394,  0.12018394,
        1.65639269,  0.6621207 ,  0.31588989, -0.80312634, -0.51586562,
       -0.50761926,  1.89072657, -0.27219382, -0.09268516,  0.99893546,
       -0.07125062,  0.40644982,  0.17609282, -0.18786068, -0.62

In [184]:
# Compute the similarity between two training documents, identified by their tags.

tag_a, tag_b = 0, 1
model.docvecs.similarity(tag_a, tag_b)

0.28919744734244612

In [185]:
# Search for the training documents most similar to a given tag, then show the
# query document and its best match (category number first, then the text).

query_tag = 3
results = model.docvecs.most_similar(query_tag, topn=10)
for r in results:
    print(r)

# Each result is a (tag, similarity) pair, best match first. Read the best tag
# from this run's results: a tag copied from an earlier run's output goes stale
# as soon as the model is retrained.
best_tag = results[0][0]

print("========================================")
print(mapping[query_tag][1])
print(mapping[query_tag][0])
print("========================================")
print(mapping[best_tag][1])
print(mapping[best_tag][0])

(926, 0.7829058170318604)
(200, 0.4949866235256195)
(678, 0.4677603244781494)
(1080, 0.4428390860557556)
(657, 0.4364931881427765)
(1303, 0.432651162147522)
(394, 0.4254303574562073)
(546, 0.4241424798965454)
(747, 0.42403703927993774)
(425, 0.4236903786659241)
1
} The roar at Michigan and Trumbull should be loader than ever this year.  With
} Mike Illitch at the head and Ernie Harwell back at the booth, the tiger bats
} will bang this summer.  Already they have scored 20 runs in two games and with
} Fielder, Tettleton, and Deer I think they can win the division.  No pitching!
} Bull!  Gully, Moore, Wells, and Krueger make up a decent staff that will keep
} the team into many games.  Then there is Henneman to close it out.  Watch out
} Boston, Toronto, and Baltimore - the Motor City Kittys are back.

nice woofing (or should i say meowing?).
and yes, the Tiggers are a fun, exciting team that i would pay to see.
but last year, they went 75-87. this year, their offense is essentially
the 

In [186]:
# Search for the words most similar to a given word.
# Word-level queries go through the model's word vectors (`model.wv`); the
# same-named method on the model itself is only a deprecated wrapper.

model.wv.most_similar("sports")

[('service', 0.329574853181839),
 ('philadelphia', 0.29402637481689453),
 ('anthem', 0.2873937785625458),
 ('required', 0.28282034397125244),
 ('wip', 0.2713123857975006),
 ('contemporary', 0.26715087890625),
 ('formatted', 0.2650032043457031),
 ('fr', 0.26084935665130615),
 ('broadcast', 0.25425955653190613),
 ('jeruselem', 0.25247570872306824)]

In [193]:
# Infer a document vector for a piece of text that was not in the training set.

test_idx = 0
doc, target = sentence2words(test.data[test_idx]), test.target[test_idx]
# The vector must be computed before it is inspected; printing its length
# first only worked when `vec` was left over from an earlier execution.
vec = model.infer_vector(doc)
print(len(vec))
vec

300


array([  5.27290702e-02,   3.35803479e-01,  -8.99047777e-02,
        -1.04133204e-01,   8.18891823e-02,   1.23454474e-01,
         1.04237244e-01,   2.56724119e-01,  -8.60964507e-03,
        -4.56707597e-01,   1.94484606e-01,   1.58296704e-01,
        -2.11743087e-01,   1.16416693e-01,  -2.94152647e-01,
         3.18214819e-02,  -2.15628281e-01,  -1.19733222e-01,
         7.66548157e-01,  -8.12229961e-02,  -2.78911918e-01,
        -4.33014989e-01,   2.44868144e-01,   3.52567136e-02,
         7.45100155e-02,  -5.90882301e-02,   1.65814221e-01,
         3.93279940e-01,   3.04908901e-01,  -1.86405867e-01,
         1.17939569e-01,  -3.34425330e-01,  -1.57517597e-01,
         1.53233737e-01,  -1.02317473e-02,   2.49742977e-02,
        -2.35311717e-01,   6.14254456e-03,   1.61509395e-01,
        -2.51031250e-01,   2.07383990e-01,  -1.08252481e-01,
         1.73451632e-01,  -2.30455935e-01,  -5.71632862e-01,
         1.47743344e-01,   4.79872636e-02,  -9.57493186e-02,
         5.58937676e-02,

In [226]:
# Infer a vector for a test document and search the training set for the
# documents most similar to it, then show the query and its best match.

test_idx = 28
doc, target = sentence2words(test.data[test_idx]), test.target[test_idx]
vec = model.infer_vector(doc)
results = model.docvecs.most_similar([vec])
for r in results:
    print(r)

# Each result is a (tag, similarity) pair, best match first. Read the best tag
# from this run's results: a tag copied from an earlier run's output goes stale
# as soon as the model is retrained.
best_tag = results[0][0]

print("========================================")
print(test.target[test_idx])
print(test.data[test_idx])
print("========================================")
print(mapping[best_tag][1])
print(mapping[best_tag][0])

(1335, 0.8738955855369568)
(391, 0.5304422378540039)
(1206, 0.46154117584228516)
(24, 0.4458889365196228)
(1007, 0.44474029541015625)
(337, 0.4394139349460602)
(687, 0.4381486475467682)
(223, 0.43766099214553833)
(1203, 0.4349975883960724)
(1097, 0.43349671363830566)
0
Archive-name: atheism/resources
Alt-atheism-archive-name: resources
Last-modified: 5 April 1993
Version: 1.1

                              Atheist Resources

                      Addresses of Atheist Organizations

                                     USA

FREEDOM FROM RELIGION FOUNDATION

Darwin fish bumper stickers and assorted other atheist paraphernalia are
available from the Freedom From Religion Foundation in the US.

Write to:  FFRF, P.O. Box 750, Madison, WI 53701.
Telephone: (608) 256-8900

EVOLUTION DESIGNS

Evolution Designs sell the "Darwin fish".  It's a fish symbol, like the ones
Christians stick on their cars, but with feet and the word "Darwin" written
inside.  The deluxe moulded 3D plastic fish is $4.9

In [227]:
# Document search using vector arithmetic between documents: one tag
# contributes positively and another negatively to the query.

positive_tags = [3]
negative_tags = [926]
model.docvecs.most_similar(positive=positive_tags, negative=negative_tags, topn=10)

[(708, 0.1411304622888565),
 (1661, 0.13870877027511597),
 (812, 0.13440994918346405),
 (1436, 0.13286656141281128),
 (195, 0.11860021203756332),
 (1511, 0.11833393573760986),
 (1134, 0.11139731854200363),
 (1527, 0.11136351525783539),
 (283, 0.10770661383867264),
 (981, 0.09978442639112473)]