In [1]:
from sklearn.datasets import fetch_20newsgroups

In [2]:
categories = ['alt.atheism', 'soc.religion.christian', 'comp.graphics', 'sci.med']

In [3]:
twenty_train = fetch_20newsgroups(subset='train', categories=categories, shuffle=True, random_state=42)

In [4]:
twenty_test = fetch_20newsgroups(subset='test', categories=categories, shuffle=True, random_state=42)

#### Building a pipeline

In [10]:
from sklearn.pipeline import Pipeline
from sklearn.naive_bayes import MultinomialNB
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer

In [11]:
text_clf = Pipeline([('vect', CountVectorizer()),
                     ('tfidf', TfidfTransformer()),
                     ('clf', MultinomialNB()),
                    ])

In [12]:
text_clf.fit(twenty_train.data, twenty_train.target)

Pipeline(memory=None,
     steps=[('vect', CountVectorizer(analyzer='word', binary=False, decode_error='strict',
        dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',
        lowercase=True, max_df=1.0, max_features=None, min_df=1,
        ngram_range=(1, 1), preprocessor=None, stop_words=None,
        strip...inear_tf=False, use_idf=True)), ('clf', MultinomialNB(alpha=1.0, class_prior=None, fit_prior=True))])

In [13]:
text_clf.score(twenty_test.data, twenty_test.target)

0.83488681757656458

In [14]:
from sklearn.linear_model import SGDClassifier

In [15]:
text_clf = Pipeline([('vect', CountVectorizer()),
                     ('tfidf', TfidfTransformer()),
                     ('clf', SGDClassifier()),
                    ])



In [16]:
text_clf.fit(twenty_train.data, twenty_train.target)

Pipeline(memory=None,
     steps=[('vect', CountVectorizer(analyzer='word', binary=False, decode_error='strict',
        dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',
        lowercase=True, max_df=1.0, max_features=None, min_df=1,
        ngram_range=(1, 1), preprocessor=None, stop_words=None,
        strip...='l2', power_t=0.5, random_state=None,
       shuffle=True, tol=None, verbose=0, warm_start=False))])

In [17]:
text_clf.score(twenty_test.data, twenty_test.target)

0.91944074567243672

### Grid-search on Pipeline

In [18]:
from sklearn.model_selection import GridSearchCV

In [19]:
parameters = {'vect__ngram_range':[(1,1), (1,2)],
              'tfidf__use_idf':(True, False),
              'clf__alpha':(1e-2, 1e-3),
}

In [21]:
gs_clf = GridSearchCV(text_clf, parameters, n_jobs=1)

In [22]:
gs_clf.fit(twenty_train.data, twenty_train.target)

GridSearchCV(cv=None, error_score='raise',
       estimator=Pipeline(memory=None,
     steps=[('vect', CountVectorizer(analyzer='word', binary=False, decode_error='strict',
        dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',
        lowercase=True, max_df=1.0, max_features=None, min_df=1,
        ngram_range=(1, 1), preprocessor=None, stop_words=None,
        strip...='l2', power_t=0.5, random_state=None,
       shuffle=True, tol=None, verbose=0, warm_start=False))]),
       fit_params=None, iid=True, n_jobs=1,
       param_grid={'tfidf__use_idf': (True, False), 'vect__ngram_range': [(1, 1), (1, 2)], 'clf__alpha': (0.01, 0.001)},
       pre_dispatch='2*n_jobs', refit=True, return_train_score=True,
       scoring=None, verbose=0)

In [23]:
gs_clf.score(twenty_test.data, twenty_test.target)

0.90812250332889477

In [26]:
gs_clf.predict(['god save us all'])

array([3], dtype=int64)

In [27]:
twenty_train.target_names

['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']

In [28]:
gs_clf.predict(['i am in good health'])

array([2], dtype=int64)

In [32]:
gs_clf.predict(['god me with good health'])

array([2], dtype=int64)

In [29]:
from sklearn import metrics

In [36]:
predicted = gs_clf.predict(twenty_test.data)
metrics.confusion_matrix(twenty_test.target, predicted)

array([[252,  11,  16,  40],
       [  3, 380,   2,   4],
       [  4,  36, 351,   5],
       [  5,  10,   2, 381]], dtype=int64)

In [55]:
import numpy as np
np.where(twenty_test.target != predicted)[0]

array([   8,   12,   15,   19,   36,   85,  105,  106,  117,  124,  128,
        135,  146,  174,  191,  220,  229,  230,  234,  253,  254,  257,
        288,  292,  301,  306,  318,  333,  347,  355,  364,  368,  391,
        392,  401,  411,  413,  420,  431,  433,  442,  449,  472,  485,
        499,  514,  523,  534,  535,  540,  559,  564,  580,  583,  588,
        598,  613,  671,  682,  689,  720,  729,  735,  741,  742,  744,
        746,  756,  768,  773,  786,  791,  793,  809,  810,  813,  817,
        849,  858,  874,  875,  877,  883,  889,  908,  936,  948,  954,
        956,  978,  993,  996, 1008, 1015, 1018, 1026, 1045, 1046, 1047,
       1051, 1055, 1056, 1112, 1118, 1123, 1134, 1143, 1144, 1158, 1174,
       1189, 1207, 1244, 1261, 1268, 1302, 1303, 1307, 1335, 1340, 1349,
       1350, 1352, 1353, 1354, 1358, 1369, 1377, 1384, 1393, 1419, 1436,
       1438, 1440, 1449, 1450, 1455, 1480], dtype=int64)

In [56]:
twenty_test.data[8]

