In [1]:
import numpy as np
from sklearn.datasets import fetch_20newsgroups

# Load the training split of 20 Newsgroups. Headers, signature footers and
# quoted replies are stripped so the classifier has to work from message
# bodies (the scikit-learn docs warn these parts make the task unrealistically easy).
newsgroups_train = fetch_20newsgroups(
    subset="train",
    remove=("headers", "footers", "quotes"),
)

In [2]:
# `sklearn.feature_extraction.stop_words` was removed in scikit-learn 0.24;
# ENGLISH_STOP_WORDS is importable from `sklearn.feature_extraction.text`.
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS, TfidfVectorizer

# "english" resolves to the same built-in ENGLISH_STOP_WORDS set. Newer
# scikit-learn releases validate `stop_words` as "english", a list, or None,
# so passing the frozenset object directly is rejected.
vectorizer = TfidfVectorizer(stop_words="english")

# Learn vocabulary + IDF on the training corpus and vectorize it in one pass;
# equivalent to fit(...) followed by transform(...) on the same data.
X_train_vectors = vectorizer.fit_transform(newsgroups_train.data)

In [3]:
from sklearn.naive_bayes import ComplementNB

# Additive smoothing strength (the library default is 1.0).
CNB_ALPHA = 0.9

clf = ComplementNB(
    alpha=CNB_ALPHA,
    class_prior=None,
    fit_prior=False,
    norm=False,
)
# fit() returns the estimator itself, so `clf` is the fitted model afterwards.
clf.fit(X_train_vectors, newsgroups_train.target)

In [4]:
def show_top10(classifier, vectorizer, categories, n=10):
    """Print the ``n`` highest-weighted vocabulary terms for each class.

    Parameters
    ----------
    classifier : fitted naive Bayes estimator
        Must expose ``feature_log_prob_`` indexed as ``[class_index, feature_index]``.
        NOTE: for ComplementNB this attribute holds empirical weights for the
        class *complements*, so the "top" terms tend to be tokens that are rare
        outside the class (often very rare tokens overall).
    vectorizer : fitted text vectorizer
        Supplies the vocabulary via ``get_feature_names_out()`` or, on older
        scikit-learn releases, ``get_feature_names()``.
    categories : sequence of str
        Class labels in the same order as the classifier's rows, e.g. the
        ``target_names`` of the dataset the classifier was fitted on.
    n : int, default 10
        Number of terms to print per class; ``n <= 0`` prints no terms.

    Terms are printed in ascending order of weight (strongest term last).
    """
    # get_feature_names() was deprecated in scikit-learn 1.0 and removed in 1.2.
    if hasattr(vectorizer, "get_feature_names_out"):
        names = vectorizer.get_feature_names_out()
    else:
        names = vectorizer.get_feature_names()
    feature_names = np.asarray(names)
    for i, category in enumerate(categories):
        order = np.argsort(classifier.feature_log_prob_[i])
        # argsort is ascending, so the tail holds the n largest weights. Clamp the
        # start so n <= 0 yields nothing (a [-0:] slice would return everything)
        # and n > n_features yields every feature.
        top = order[max(len(order) - n, 0):]
        print("%s: %s" % (category, " ".join(feature_names[top])))
        
# Inspect which terms the fitted model weights most heavily for each newsgroup.
show_top10(
    classifier=clf,
    vectorizer=vectorizer,
    categories=newsgroups_train.target_names,
)

alt.atheism: bevans geocentrism circulus lewd abiliy proselytizers propter circumstantial bethulah elee
comp.graphics: epsf 020637 epsi ove outwards criiterion outrunning outputing cricket visualizer
comp.os.ms-windows.misc: auhl7 om9xax 82c607 82bbzt _supercharging hix 3958784 auh hitr 3iirlj103j1
comp.sys.ibm.pc.hardware: suggeted me2 allister assisatnce 2350 libc 8800cs 02h drdos6 persnickity
comp.sys.mac.hardware: poweropen powerpcs eeeee 5el3 1722 10pnt powerstrip ichips trafa catchup
comp.windows.x: 8vao est5edt popen_xphigs llat popen_ws fractionally nassestr wellorganized bursty n86pl
misc.forsale: thuan ft6000 22bis neville wheelwriter heise thums zmed16 thoren availble
rec.autos: olof citroen olde yjs vavau gearboxes bellevue gearshift opdbs tercel
rec.motorcycles: lotta spooge xz550 sportmax springers sprocket sprockets squeak sported me77
rec.sport.baseball: tantrum tanstaafl riles huckabay slyke dhenderson 30s drayton scoresheets ma_ind25
rec.sport.hockey: looming dade vir

In [5]:
from sklearn.metrics import classification_report

# Held-out split, cleaned the same way as the training split.
newsgroups_test = fetch_20newsgroups(subset='test', remove=('headers', 'footers', 'quotes'))

# transform only (no refit): test documents must be mapped into the
# vocabulary/IDF space learned from the training split.
X_test_vectors = vectorizer.transform(newsgroups_test.data)
predicts = clf.predict(X_test_vectors)

report = classification_report(
    newsgroups_test.target,
    predicts,
    target_names=newsgroups_test.target_names,
)
print(report)

                          precision    recall  f1-score   support

             alt.atheism       0.31      0.42      0.36       319
           comp.graphics       0.73      0.72      0.72       389
 comp.os.ms-windows.misc       0.70      0.61      0.65       394
comp.sys.ibm.pc.hardware       0.64      0.70      0.67       392
   comp.sys.mac.hardware       0.76      0.73      0.74       385
          comp.windows.x       0.82      0.79      0.80       395
            misc.forsale       0.77      0.74      0.75       390
               rec.autos       0.82      0.75      0.78       396
         rec.motorcycles       0.84      0.77      0.81       398
      rec.sport.baseball       0.92      0.83      0.88       397
        rec.sport.hockey       0.84      0.94      0.89       399
               sci.crypt       0.75      0.81      0.78       396
         sci.electronics       0.71      0.55      0.62       393
                 sci.med       0.83      0.81      0.82       396
         