In [75]:
import pickle

import numpy as np
from sklearn import metrics
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

In [27]:
# Load the training split exactly once. The fixed random_state keeps the
# shuffled document order reproducible between runs.
twenty_train = fetch_20newsgroups(subset='train', shuffle=True, random_state=42)

In [28]:
# Learn the vocabulary from the training documents and convert them into a
# document-term matrix of raw token counts.
count_vect = CountVectorizer()
X_train_counts = count_vect.fit_transform(twenty_train.data)
# vocabulary_ maps each token to its feature index; .get returns None when the
# token was never seen.
count_vect.vocabulary_.get(u'algorithm')

27366

In [31]:
# Re-weight the raw counts with TF-IDF; the weights are fitted on the
# training count matrix.
tfidf_transformer = TfidfTransformer()
X_train_tfidf = tfidf_transformer.fit_transform(X_train_counts)
# Displayed as (number of documents, number of features).
X_train_tfidf.shape

(11314, 130107)

In [32]:
# Fit a multinomial naive Bayes classifier on the TF-IDF features and the
# newsgroup labels of the training split.
clf = MultinomialNB().fit(X_train_tfidf, twenty_train.target)

In [33]:
# New documents go through the already-fitted vectorizer and transformer
# (transform, not fit_transform) so they land in the training feature space.
docs_new = ['God is love', 'OpenGL on the GPU is fast']
X_new_counts = count_vect.transform(docs_new)
X_new_tfidf = tfidf_transformer.transform(X_new_counts)

In [34]:
# One predicted class index per entry of docs_new; the indices are looked up
# in twenty_train.target_names in the next cell.
predicted = clf.predict(X_new_tfidf)

In [35]:
# Show each new document next to the newsgroup name its predicted index maps to.
for text, label_idx in zip(docs_new, predicted):
    label_name = twenty_train.target_names[label_idx]
    print('%r => %s' % (text, label_name))

'God is love' => soc.religion.christian
'OpenGL on the GPU is fast' => rec.autos


In [36]:
# Chain vectorizing, TF-IDF weighting and the classifier into one estimator
# so raw text can be passed straight to fit/predict.
# Step names ('vect', 'tfidf', 'clf') are reused later as prefixes in the
# grid-search parameter names.
text_clf = Pipeline([
    ('vect', CountVectorizer()),
    ('tfidf', TfidfTransformer()),
    ('clf', MultinomialNB()),
])

In [40]:
# Fit every pipeline step on the raw training text in a single call.
text_clf.fit(twenty_train.data, twenty_train.target)  
# fit() is the last expression, so the fitted Pipeline is echoed as the cell output.

Pipeline(memory=None,
     steps=[('vect', CountVectorizer(analyzer=u'word', binary=False, decode_error=u'strict',
        dtype=<type 'numpy.int64'>, encoding=u'utf-8', input=u'content',
        lowercase=True, max_df=1.0, max_features=None, min_df=1,
        ngram_range=(1, 1), preprocessor=None, stop_words=None,
        st...False,
         use_idf=True)), ('clf', MultinomialNB(alpha=1.0, class_prior=None, fit_prior=True))])

In [41]:
# Load the held-out test split with the same shuffle seed as the training split.
twenty_test = fetch_20newsgroups(subset='test', shuffle=True, random_state=42)
docs_test = twenty_test.data
predicted = text_clf.predict(docs_test)
# Accuracy: fraction of test documents whose predicted label equals the true label.
(predicted == twenty_test.target).mean()

0.7738980350504514

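Now swap the naive Bayes classifier for a linear SVM (SGDClassifier with hinge loss), retrain on the training set, and evaluate it on the same test set.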

In [45]:
# Same pipeline shape as before, but the final step is a linear SVM trained
# with stochastic gradient descent (hinge loss, L2 penalty).
# Note: this rebinds text_clf, replacing the naive Bayes pipeline.
text_clf = Pipeline([
    ('vect', CountVectorizer()),
    ('tfidf', TfidfTransformer()),
    ('clf', SGDClassifier(loss='hinge', penalty='l2',
                          alpha=1e-3, random_state=42,
                          max_iter=5, tol=None)),
])
text_clf.fit(twenty_train.data, twenty_train.target)

# Re-score on the test documents loaded earlier; predicted is overwritten with
# the SVM's predictions and is reused by the report cell below.
predicted = text_clf.predict(docs_test)
np.mean(predicted == twenty_test.target)

0.82381837493361654

In [47]:
# Per-class precision / recall / F1 on the test split, printed as text,
# followed by the confusion matrix as the cell's displayed value.
report_text = metrics.classification_report(
    twenty_test.target,
    predicted,
    target_names=twenty_test.target_names,
)
print(report_text)
metrics.confusion_matrix(twenty_test.target, predicted)

                          precision    recall  f1-score   support

             alt.atheism       0.73      0.72      0.72       319
           comp.graphics       0.80      0.70      0.74       389
 comp.os.ms-windows.misc       0.73      0.76      0.75       394
comp.sys.ibm.pc.hardware       0.71      0.70      0.70       392
   comp.sys.mac.hardware       0.83      0.81      0.82       385
          comp.windows.x       0.83      0.77      0.80       395
            misc.forsale       0.84      0.90      0.87       390
               rec.autos       0.92      0.89      0.91       396
         rec.motorcycles       0.92      0.96      0.94       398
      rec.sport.baseball       0.89      0.90      0.89       397
        rec.sport.hockey       0.88      0.99      0.93       399
               sci.crypt       0.83      0.96      0.89       396
         sci.electronics       0.83      0.60      0.70       393
                 sci.med       0.87      0.86      0.86       396
         

array([[230,   0,   0,   1,   0,   2,   1,   0,   1,   3,   0,   2,   1,
         11,   5,  41,   2,   8,   1,  10],
       [  3, 272,  21,  11,   7,  25,   4,   1,   3,   4,   3,   9,   4,
          3,   9,   3,   2,   4,   0,   1],
       [  1,   9, 301,  26,  10,  13,   2,   0,   0,   7,   2,   9,   1,
          2,   7,   1,   0,   1,   1,   1],
       [  3,   9,  27, 274,  22,   3,  12,   3,   4,   1,   1,   4,  19,
          2,   4,   0,   1,   2,   1,   0],
       [  0,   5,   8,  26, 313,   2,   9,   0,   1,   4,   1,   3,   6,
          1,   1,   0,   2,   1,   2,   0],
       [  1,  29,  39,   1,   2, 304,   2,   0,   1,   1,   1,   3,   1,
          1,   7,   1,   1,   0,   0,   0],
       [  0,   2,   0,  14,   4,   0, 352,   6,   1,   1,   2,   1,   2,
          2,   2,   0,   1,   0,   0,   0],
       [  1,   1,   0,   2,   1,   0,  10, 354,   8,   2,   0,   0,  10,
          0,   3,   0,   3,   0,   1,   0],
       [  0,   0,   0,   1,   0,   0,   4,   6, 384,   2,   0,  

In [49]:
# Grid-search space. Keys follow the "<pipeline step>__<parameter>" pattern,
# so each entry targets one named step of text_clf ('vect', 'tfidf', 'clf').
parameters = {'vect__ngram_range': [(1, 1), (1, 2)],
              'tfidf__use_idf': (True, False),
              'clf__alpha': (1e-2, 1e-3),
}

In [50]:
# Search every combination in `parameters` over the current text_clf pipeline.
# n_jobs=-1 asks scikit-learn to run the search in parallel on all available cores.
gs_clf = GridSearchCV(text_clf, parameters, n_jobs=-1)

In [51]:
# Only the first 400 training documents are used for the search, so the
# scores below reflect that small subset rather than the full training split.
gs_clf = gs_clf.fit(twenty_train.data[:400], twenty_train.target[:400])

In [52]:
# Classify a single sentence with the fitted grid search and map the
# predicted class index back to its newsgroup name.
twenty_train.target_names[gs_clf.predict(['God is love'])[0]]

'soc.religion.christian'

In [53]:
# Mean cross-validated score of the best parameter combination found.
gs_clf.best_score_

0.60250000000000004

In [54]:
# Report the winning value for every searched parameter, in alphabetical order.
best_params = gs_clf.best_params_
for name in sorted(parameters):
    print("%s: %r" % (name, best_params[name]))

clf__alpha: 0.001
tfidf__use_idf: True
vect__ngram_range: (1, 1)


In [73]:
def test_new_article(txt):
    X_new_counts1 = count_vect.transform(txt)
    X_new_tfidf1 = tfidf_transformer.transform(X_new_counts1)
    ##returns an array of predictions
    preds = clf.predict(X_new_tfidf1)
    return [twenty_train.target_names[i] for i in preds]

In [81]:
# Classify one made-up snippet; the result is a list holding one category name.
b = test_new_article(['ball stick run fast, exciting, crowds, fans, speed'])

In [77]:
## pickle the model for later
# NOTE(review): only the classifier is saved here. test_new_article also needs
# the fitted count_vect and tfidf_transformer, so those must be persisted as
# well before the model can classify raw text in a fresh session.
with open('trained_model.pkl', 'wb') as f:
    pickle.dump(clf, f)

In [80]:
# and later you can load it
# This rebinds the notebook-level clf to the unpickled object.
# pickle.load can execute arbitrary code, so only load files you created yourself.
with open('trained_model.pkl', 'rb') as f:
    clf = pickle.load(f)

In [86]:
# First (and only) category name returned by the test_new_article call above.
b[0]

'rec.sport.baseball'