In [1]:
from pprint import pprint
from time import time

import numpy as np
from gensim import matutils
from gensim.corpora import Dictionary
from gensim.models.ldamodel import LdaModel
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer

Использую стандартный датасет **20newsgroups** (загружается через `sklearn.datasets.fetch_20newsgroups`).

In [2]:
# Full training split over all 20 categories; subset='train' is spelled out
# (it is fetch_20newsgroups' default) to mirror the filtered fetch further down.
newsgroups = fetch_20newsgroups(subset='train')

In [3]:
# List the available newsgroup categories.
category_names = list(newsgroups.target_names)
pprint(category_names)

['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']


In [4]:
# Number of documents in the full training split (one filename per document).
newsgroups['filenames'].shape

(11314,)

Немного сокращу датасет, оставив 5 различных тем.

In [5]:
# Restrict the corpus to five categories; this rebinds `newsgroups` to the filtered training split.
cats = [
    'soc.religion.christian',
    'talk.religion.misc',
    'sci.space',
    'rec.motorcycles',
    'comp.graphics',
]
newsgroups = fetch_20newsgroups(categories=cats, subset='train')

In [6]:
# Document count after keeping only the selected categories.
newsgroups['filenames'].shape

(2751,)

In [7]:
# Raw text of the first document; note that the mail headers are still included.
newsgroups['data'][0]

u"From: henry@zoo.toronto.edu (Henry Spencer)\nSubject: Re: Solar Sail Data\nOrganization: U of Toronto Zoology\nLines: 10\n\nIn article <1qk4qf$mf8@male.EBay.Sun.COM> almo@packmind.EBay.Sun.COM writes:\n>Hey!? What happened to the solar sail race that was supposed to be\n>for Columbus+500?\n\nThere was a recession, and none of the potential entrants could raise any\nmoney.  The race organizers were actually supposed to be handling part of\nthe fundraising, but the less said about that the better.\n-- \nAll work is one man's work.             | Henry Spencer @ U of Toronto Zoology\n                    - Kipling           |  henry@zoo.toronto.edu  utzoo!henry\n"

In [8]:
# Bag-of-words counts: drop English stop words, terms that occur in fewer than
# 2 documents (min_df) and terms present in more than 40% of documents (max_df).
# X has one sparse row per document and one column per vocabulary term.
vectorizer = CountVectorizer(
    analyzer='word',
    stop_words='english',
    min_df=2,
    max_df=0.4,
)
X = vectorizer.fit_transform(newsgroups.data)

In [9]:
# Term counts of the second document, as a 1 x n_terms sparse row.
X[1, :]

<1x20971 sparse matrix of type '<type 'numpy.int64'>'
	with 64 stored elements in Compressed Sparse Row format>

In [10]:
# Vocabulary in column order of X: entry i is the term counted in column i.
# `get_feature_names` was removed in scikit-learn 1.2 in favour of
# `get_feature_names_out`; use the new API when present so the notebook runs on
# both old and current scikit-learn, and convert to a plain list of str.
if hasattr(vectorizer, 'get_feature_names_out'):
    vocab = vectorizer.get_feature_names_out().tolist()
else:
    vocab = vectorizer.get_feature_names()

In [11]:
# Stream the rows of X (one row per document) as gensim bag-of-words vectors.
corpus = matutils.Sparse2Corpus(X, documents_columns=False)
# Column index -> term mapping, so gensim can print topics as words.
id2word = dict(enumerate(vocab))

In [12]:
# Train a 5-topic LDA model with 30 passes over the corpus (alpha='auto' asks
# gensim to learn the document-topic prior from the data).
# Seed NumPy's global RNG first so the topics are reproducible between runs
# (assumes this gensim version samples from it when no `random_state` is passed - TODO confirm).
np.random.seed(42)
start = time()
lda = LdaModel(corpus, id2word=id2word, num_topics=5, alpha='auto', update_every=1, passes=30)
# print() as a function (valid on Python 2 and 3, like the other cells); the value is in minutes.
print('Training time: {:.2f} min'.format((time() - start) / 60))

Evaluation time: 3.76850238244


In [13]:
# Top words of every learned topic; ask for exactly as many topics as the model has.
lda.print_topics(num_topics=lda.num_topics)

[(0,
  u'0.015*god + 0.009*jesus + 0.007*people + 0.005*think + 0.005*christian + 0.004*bible + 0.004*christ + 0.004*don + 0.004*know + 0.004*believe'),
 (1,
  u'0.008*image + 0.006*graphics + 0.004*data + 0.004*images + 0.004*software + 0.004*available + 0.003*___ + 0.003*ca + 0.003*ftp + 0.003*file'),
 (2,
  u'0.005*people + 0.005*does + 0.005*don + 0.005*just + 0.004*know + 0.004*like + 0.004*university + 0.004*think + 0.004*science + 0.004*say'),
 (3,
  u'0.018*space + 0.008*nasa + 0.005*launch + 0.004*gov + 0.004*earth + 0.004*moon + 0.003*orbit + 0.003*satellite + 0.003*shuttle + 0.003*year'),
 (4,
  u'0.007*bike + 0.007*posting + 0.006*nntp + 0.006*host + 0.005*like + 0.005*dod + 0.005*university + 0.005*ca + 0.004*don + 0.004*just')]

Стоило бы предварительно очистить тексты от специфичных для этого корпуса стоп-слов (служебных слов из заголовков писем, например `nntp`, `posting`, `host`, `university`), но даже без этого темы в целом угадываются, хотя и не идеально. Обучение заняло всего около 3,8 минуты; качество, вероятно, можно улучшить, увеличив число проходов по документам (`passes`).

In [14]:
# Tokenise the free-text query with the vectorizer's own analyzer (lowercasing,
# token pattern, English stop-word removal) so its tokens match the training
# vocabulary; a plain str.split() keeps case and punctuation and can miss terms.
query_text = 'scientists explore new star'
query = vectorizer.build_analyzer()(query_text)

In [15]:
# gensim Dictionary carrying exactly the id -> term mapping the model was trained
# with, so doc2bow emits ids that match the LDA model's columns.
# Built from the explicit mapping instead of `merge_with`: per gensim's
# implementation, merge_with assigns fresh ids while merging, so they only line up
# by virtue of iteration order. The old version also bound the discarded result to
# `_`, shadowing IPython's last-output variable.
special_id2word = Dictionary.from_corpus(corpus, id2word=id2word)

In [16]:
# Convert the query tokens into gensim bag-of-words form: (term id, count) pairs.
# NOTE: this rebinds `query` from a token list to a BOW, so the cell is not
# idempotent - re-run the tokenisation cell before running this one again.
query_bow = special_id2word.doc2bow(query)
query = query_bow
query

[(7939, 1), (13255, 1), (16841, 1), (17977, 1)]

In [17]:
# Topic mixture inferred for the query: (topic id, weight) pairs.
query_topics = lda[query]
query_topics

[(0, 0.015438042430812933),
 (1, 0.0137238189385505),
 (2, 0.011630400426774911),
 (3, 0.93976944218015324),
 (4, 0.01943829602370855)]

In [18]:
# Order the query's topics by weight, ascending: first = weakest, last = dominant.
# Inference runs again here, so the weights can differ slightly from the previous cell.
sorted_results = sorted(lda[query], key=lambda topic_weight: topic_weight[1])
weakest, dominant = sorted_results[0], sorted_results[-1]
print(weakest)
print(dominant)

(2, 0.011630400416873547)
(3, 0.93976944219711334)


In [19]:
# Top words of the query's dominant topic (last entry of the ascending sort).
dominant_topic_id = sorted_results[-1][0]
lda.print_topic(dominant_topic_id)

u'0.018*space + 0.008*nasa + 0.005*launch + 0.004*gov + 0.004*earth + 0.004*moon + 0.003*orbit + 0.003*satellite + 0.003*shuttle + 0.003*year'

Сохраню модель:

In [20]:
# Persist the trained model to 'newsgroups.lda' in the current working directory.
# To reload it later, use the class imported at the top of the notebook
# (the bare `gensim` package is not imported here):
#     lda = LdaModel.load('newsgroups.lda')
lda.save('newsgroups.lda')