In [1]:
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.decomposition import NMF, LatentDirichletAllocation
import numpy as np

# Fetch the 20-newsgroups posts, shuffled, with headers, footers and quoted
# replies stripped so that only the body text of each post remains.
newsgroups = fetch_20newsgroups(
    shuffle=True,
    remove=('headers', 'footers', 'quotes'),
)

# Names of the labelled newsgroups, kept for comparison with the learned topics.
true_topics = newsgroups.target_names

# Number of topics to extract: the dataset has 20 newsgroups.
nr_topics = 20

In [2]:
# Build TF-IDF features. Both thresholds are *document* frequencies:
#   max_df=0.95 -> drop terms that occur in more than 95% of the documents
#   min_df=2    -> drop terms that occur in fewer than 2 documents
# English stop words are removed as well.
vectorizer = TfidfVectorizer(max_df=0.95, min_df=2, stop_words='english')
vectors = vectorizer.fit_transform(newsgroups.data)

# get_feature_names() was deprecated in scikit-learn 1.0 and removed in 1.2;
# prefer get_feature_names_out() and fall back only on older releases.
if hasattr(vectorizer, "get_feature_names_out"):
    feature_names = np.asarray(vectorizer.get_feature_names_out())
else:
    feature_names = np.asarray(vectorizer.get_feature_names())

# matrix containing documents in rows and words in columns
print(vectors.shape)

(11314, 39115)


In [3]:
# Configure the non-negative matrix factorization: one component per topic,
# 'nndsvd' initialization and an iteration cap of 300.
nmf = NMF(
    n_components=nr_topics,
    init='nndsvd',
    max_iter=300,
)

In [4]:
# Fit NMF on the TF-IDF matrix so that vectors is approximated by W * H:
# W is returned by fit_transform, H is the fitted components_ attribute.
W = nmf.fit_transform(vectors)
H = nmf.components_

# Report the iterations needed and then the reconstruction error, one per line.
print(nmf.n_iter_, nmf.reconstruction_err_, sep="\n")

184
102.52151111971088


In [5]:
# how many of the highest-weighted words to list for each topic
nr_words = 3

print(f"(topics, words): {H.shape}\n")

print(f"Top {nr_words} words for each topic:")

for id_topic, topic in enumerate(H):
    # argsort is ascending, so the last nr_words indices carry the largest
    # weights (listed from smallest to largest of those)
    important_words = np.argsort(topic)[-nr_words:]
    print(f"Topic {id_topic}: {', '.join(feature_names[important_words])}")

(topics, words): (20, 39115)

Top 3 words for each topic:
Topic 0: like, don, just
Topic 1: dos, file, windows
Topic 2: bible, jesus, god
Topic 3: chastity, dsl, geb
Topic 4: encryption, chip, key
Topic 5: hard, disk, drive
Topic 6: 10, sale, 00
Topic 7: advance, mail, thanks
Topic 8: gun, government, people
Topic 9: monitor, video, card
Topic 10: games, team, game
Topic 11: cars, bike, car
Topic 12: application, motif, window
Topic 13: shuttle, nasa, space
Topic 14: anybody, know, does
Topic 15: jews, israeli, israel
Topic 16: turkish, armenians, armenian
Topic 17: controller, ide, scsi
Topic 18: ftp, com, edu
Topic 19: mac, software, use


In [6]:
# List the labelled newsgroup names, one per line, for comparison with the
# topics found by NMF.
print("True topics:\n")

for topic_name in true_topics:
    print(topic_name)

True topics:

alt.atheism
comp.graphics
comp.os.ms-windows.misc
comp.sys.ibm.pc.hardware
comp.sys.mac.hardware
comp.windows.x
misc.forsale
rec.autos
rec.motorcycles
rec.sport.baseball
rec.sport.hockey
sci.crypt
sci.electronics
sci.med
sci.space
soc.religion.christian
talk.politics.guns
talk.politics.mideast
talk.politics.misc
talk.religion.misc


In [7]:
# W is the result of nmf.fit_transform(vectors): one row per document and
# one column per NMF component (topic).
print(W.shape)

(11314, 20)


In [8]:
# Index of the article to inspect; its predicted topic is the column of W
# holding the largest weight in that article's row.
art = 0
print(f"Article {art} is predicted to belong to topic: {np.argmax(W[art])}")

Article 0 is predicted to belong to topic: 11


In [9]:
# Path of the raw file behind article `art`; in the recorded run its parent
# folder is the newsgroup name, which serves as the true label to compare with.
# NOTE(review): the displayed value is an absolute local path (it exposes the
# user directory) -- presumably newsgroups.target_names[newsgroups.target[art]]
# would show the label alone; confirm before switching.
newsgroups.filenames[art]

'C:\\Users\\lanze\\scikit_learn_data\\20news_home\\20news-bydate-train\\rec.autos\\102994'