Code modified from a scikit-learn tutorial on text classification with 20 Newsgroups: https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html

In [None]:
from sklearn.datasets import fetch_20newsgroups

Load the 20 Newsgroups training set. There are twenty categories in total (listed in the comment).

In [None]:
'''
categories = ["comp.graphics", "comp.os.ms-windows.misc", "comp.sys.ibm.pc.hardware",
              "comp.sys.mac.hardware", "comp.windows.x", "rec.autos", "rec.motorcycles",
              "rec.sport.baseball", "rec.sport.hockey", "sci.crypt", "sci.electronics",
              "sci.med", "sci.space", "misc.forsale", "talk.politics.misc",
              "talk.politics.guns", "talk.politics.mideast", "talk.religion.misc",
              "alt.atheism", "soc.religion.christian"]
'''

#Load the training set
twenty_train = fetch_20newsgroups(subset="train",
                                  categories=None, shuffle=True, random_state=42)

print(twenty_train.target_names)
print(len(twenty_train.data))

['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x', 'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 'sci.med', 'sci.space', 'soc.religion.christian', 'talk.politics.guns', 'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc']
11314


According to the scikit-learn documentation, the variable `twenty_train` is a dictionary-like object, which notably contains:


*   `data`: A list in which each entry holds the text of one document
*   `target`: An array of target labels, stored as integer indices
*   `target_names`: A list of the names of the target classes (a label `i` in `target` refers to `target_names[i]`)
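
To make the relationship between these three fields concrete, here is a minimal sketch that uses small hand-written stand-ins. The lists below are hypothetical toy values, not the real dataset. It shows how the integer entries in `target` index into `target_names`, and how to count documents per class.

```python
from collections import Counter

# Hypothetical toy stand-ins for twenty_train.data / .target / .target_names
data = [
    "My engine stalls at idle.",
    "The shuttle launch was delayed.",
    "Which tires fit this rim?",
]
target = [0, 1, 0]                         # one integer label per document
target_names = ["rec.autos", "sci.space"]  # index -> class name

# Map each integer label back to its class name
labels = [target_names[i] for i in target]
for text, label in zip(data, labels):
    print(f"{label}: {text}")

# Count how many documents fall into each class
class_counts = Counter(labels)
print(class_counts)
```

On the real object, the same pattern should work with `twenty_train.target` and `twenty_train.target_names` in place of the toy lists.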




In [None]:
print(twenty_train.data[0]) #Text of the first document
print(twenty_train.target[0]) #Label of the first document (as an index)
print(twenty_train.target_names[twenty_train.target[0]]) #Label (in text form)

From: lerxst@wam.umd.edu (where's my thing)
Subject: WHAT car is this!?
Nntp-Posting-Host: rac3.wam.umd.edu
Organization: University of Maryland, College Park
Lines: 15

 I was wondering if anyone out there could enlighten me on this car I saw
the other day. It was a 2-door sports car, looked to be from the late 60s/
early 70s. It was called a Bricklin. The doors were really small. In addition,
the front bumper was separate from the rest of the body. This is 
all I know. If anyone can tellme a model name, engine specs, years
of production, where this car is made, history, or whatever info you
have on this funky looking car, please e-mail.

Thanks,
- IL
   ---- brought to you by your neighborhood Lerxst ----





7
rec.autos


We now get word counts for our text documents using scikit-learn's `CountVectorizer`. It stores the counts in a sparse matrix, since the vocabulary is large and the vast majority of entries in each row are zero.
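
As a small illustration of what the count matrix looks like, the sketch below runs `CountVectorizer` on a hypothetical three-document toy corpus (not part of 20 Newsgroups). With default settings the vectorizer lowercases the text and drops single-character tokens, which is why "a" does not get a column.

```python
from sklearn.feature_extraction.text import CountVectorizer

# Toy corpus (hypothetical, for illustration only)
toy_docs = ["the car is fast", "the car is a car", "rockets reach orbit"]

toy_vect = CountVectorizer()
toy_counts = toy_vect.fit_transform(toy_docs)  # SciPy sparse matrix

# vocabulary_ maps each word to its column index; sort words by that index
vocab = sorted(toy_vect.vocabulary_, key=toy_vect.vocabulary_.get)
print(vocab)                 # column order of the count matrix
print(toy_counts.toarray())  # dense view; only sensible for tiny examples

# Only the non-zero counts are actually stored
total_cells = toy_counts.shape[0] * toy_counts.shape[1]
print(toy_counts.nnz, "non-zero entries out of", total_cells)
```

Even in this tiny example fewer than half of the cells are non-zero; with a vocabulary of over 100,000 words the fraction becomes far smaller, which is what makes the sparse format worthwhile.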

In [None]:
from sklearn.feature_extraction.text import CountVectorizer
count_vect = CountVectorizer() #Initialize CountVectorizer object
X_train = count_vect.fit_transform(twenty_train.data) #Apply to data
X_train.shape #(number of instances, vocabulary size)

(11314, 130107)

**First attempt: We train a Naive Bayes model on our counts from the training set.**

In [None]:
from sklearn.naive_bayes import MultinomialNB
clf = MultinomialNB().fit(X_train, twenty_train.target) #Training happens here

Let's test out Naive Bayes on sample inputs.

In [None]:
docs_new = ["I just scored a home run.", "That's a nice set of wheels."]
X_new_counts = count_vect.transform(docs_new)

predicted = clf.predict(X_new_counts)

for doc, category in zip(docs_new, predicted):
    print("TEXT: %r \t PREDICTION: %s" % (doc, twenty_train.target_names[category]))

TEXT: 'I just scored a home run.' 	 PREDICTION: rec.sport.baseball
TEXT: "That's a nice set of wheels." 	 PREDICTION: rec.autos
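
Note that the new sentences are passed through `transform`, not `fit_transform`, so they are mapped onto the vocabulary learned from the training set. The following sketch, on a hypothetical toy corpus, shows that the new matrix has the same columns as the training matrix and that words never seen during fitting are silently ignored.

```python
from sklearn.feature_extraction.text import CountVectorizer

# Hypothetical toy training corpus
train_docs = ["fast car engine", "rocket engine orbit"]
vect = CountVectorizer()
X_toy_train = vect.fit_transform(train_docs)  # learns the vocabulary

# New text is mapped onto the SAME columns with transform (no refitting)
new_docs = ["fast rocket", "bicycle helmet"]
X_toy_new = vect.transform(new_docs)

print(X_toy_train.shape, X_toy_new.shape)  # same number of columns
print(X_toy_new.toarray())                 # second row is all zeros: no known words
```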


Let's see how our model does on the test set. Normally, if we wanted to make manual adjustments to our model, we would first use a validation set. For simplicity (and due to time constraints), we skip that step.
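
For reference, holding out a validation set could look roughly like the sketch below. It uses scikit-learn's `train_test_split` on hypothetical toy documents and labels; the 80/20 split and the variable names are assumptions for illustration, not something this notebook does.

```python
from sklearn.model_selection import train_test_split

# Hypothetical toy documents and labels standing in for twenty_train.data / .target
docs = [f"document {i}" for i in range(10)]
labels = [i % 2 for i in range(10)]

# Hold out 20% of the training data as a validation set,
# keeping the class proportions the same in both parts (stratify)
docs_fit, docs_val, y_fit, y_val = train_test_split(
    docs, labels, test_size=0.2, random_state=42, stratify=labels
)
print(len(docs_fit), len(docs_val))
```

The model would then be fitted on `docs_fit`, tuned against `docs_val`, and scored on the test set only once at the end.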

In [None]:
import numpy as np

#Bring in the test set
twenty_test = fetch_20newsgroups(subset='test', shuffle=True, random_state=42)
docs_test = twenty_test.data
X_test = count_vect.transform(twenty_test.data)
predicted = clf.predict(X_test)
np.mean(predicted == twenty_test.target)

0.7728359001593202

77% accuracy is a strong result for a first attempt: with 20 classes, guessing uniformly at random would be right only about 5% of the time.
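
To put an accuracy figure in context, it helps to compare it with simple baselines. The sketch below uses hypothetical toy labels (not taken from the run above) to show the accuracy computation used in this notebook alongside a majority-class baseline and the chance level for 20 equally likely classes.

```python
import numpy as np

# Hypothetical toy labels and predictions (not taken from the run above)
y_true = np.array([0, 0, 0, 1, 1, 2])
y_pred = np.array([0, 0, 1, 1, 1, 0])

# Accuracy: fraction of positions where the prediction equals the true label
accuracy = np.mean(y_pred == y_true)

# Majority-class baseline: always predict the most frequent true label
majority_baseline = np.bincount(y_true).max() / len(y_true)

# Uniform random guessing over k classes is right 1/k of the time on average
chance_20_classes = 1 / 20

print(accuracy, majority_baseline, chance_20_classes)
```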

To see where the model struggles most, we can print a more detailed report with per-class precision, recall, and F1 score, followed by the confusion matrix.

In [None]:
from sklearn import metrics
print(metrics.classification_report(twenty_test.target, predicted,
                                    target_names=twenty_test.target_names))

metrics.confusion_matrix(twenty_test.target, predicted)

                          precision    recall  f1-score   support

             alt.atheism       0.79      0.77      0.78       319
           comp.graphics       0.67      0.74      0.70       389
 comp.os.ms-windows.misc       0.20      0.00      0.01       394
comp.sys.ibm.pc.hardware       0.56      0.77      0.65       392
   comp.sys.mac.hardware       0.84      0.75      0.79       385
          comp.windows.x       0.65      0.84      0.73       395
            misc.forsale       0.93      0.65      0.77       390
               rec.autos       0.87      0.91      0.89       396
         rec.motorcycles       0.96      0.92      0.94       398
      rec.sport.baseball       0.96      0.87      0.91       397
        rec.sport.hockey       0.93      0.96      0.95       399
               sci.crypt       0.67      0.95      0.78       396
         sci.electronics       0.79      0.66      0.72       393
                 sci.med       0.87      0.82      0.85       396
         

array([[245,   0,   0,   1,   0,   1,   0,   0,   1,   0,   2,   1,   1,
          2,   2,  41,   2,  11,   5,   4],
       [  1, 287,   0,  12,   4,  31,   1,   0,   0,   1,   0,  26,   5,
          2,   8,   2,   2,   1,   6,   0],
       [  2,  55,   1, 134,  13, 112,   2,   0,   1,   3,   1,  31,   4,
          4,   8,   5,   2,   1,  14,   1],
       [  0,  11,   1, 300,  15,  11,   3,   5,   0,   0,   1,  11,  23,
          0,   5,   0,   1,   2,   3,   0],
       [  0,  12,   1,  22, 289,   5,   3,   5,   1,   1,   0,  14,  10,
          3,   3,   1,   4,   2,   9,   0],
       [  1,  25,   2,  11,   1, 332,   0,   0,   0,   0,   0,  13,   0,
          2,   4,   1,   2,   1,   0,   0],
       [  0,   6,   0,  35,  17,   3, 253,  16,   4,   1,   4,   6,  16,
          7,   6,   2,   5,   4,   5,   0],
       [  0,   1,   0,   2,   0,   0,   4, 360,   3,   2,   2,   3,   0,
          0,   4,   0,   4,   2,   9,   0],
       [  0,   0,   0,   1,   0,   0,   2,  13, 365,   0,   0,  
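
The confusion matrix and the report are two views of the same counts: rows are true classes and columns are predicted classes. In the first row above, 245 of the 319 `alt.atheism` documents sit on the diagonal, which gives the recall of 0.77 shown in the report. In the third row, only 1 of the 394 `comp.os.ms-windows.misc` documents is on the diagonal, while 134 and 112 land in the columns for `comp.sys.ibm.pc.hardware` and `comp.windows.x`. The sketch below reproduces this kind of reading on a hypothetical 3-class matrix.

```python
import numpy as np

# Hypothetical 3-class confusion matrix: rows = true class, columns = predicted class
cm = np.array([[8, 2, 0],
               [1, 6, 3],
               [2, 0, 8]])

true_positives = np.diag(cm)
recall = true_positives / cm.sum(axis=1)     # correct / all documents truly in the class
precision = true_positives / cm.sum(axis=0)  # correct / all documents predicted as the class

# For each true class, the wrong class it is most often mistaken for
off_diagonal = cm.copy()
np.fill_diagonal(off_diagonal, 0)
most_confused_with = off_diagonal.argmax(axis=1)

print(recall, precision, most_confused_with)
```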

**Second attempt: Let's see what we can do with a different model. We repeat the steps above with a linear support vector machine (SVM), trained with stochastic gradient descent via `SGDClassifier`.**

In [None]:
from sklearn.linear_model import SGDClassifier
clf = SGDClassifier().fit(X_train, twenty_train.target) #Training happens here
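
`SGDClassifier` uses the hinge loss by default, which corresponds to a linear SVM. It also shuffles the training data, so without a fixed `random_state` the accuracy can differ slightly from run to run. The sketch below spells out both points on a hypothetical toy corpus; the seed value 42 is an arbitrary choice for illustration.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import SGDClassifier

# Hypothetical toy corpus with two classes
toy_docs = ["fast car engine", "car tires wheels", "rocket orbit launch", "orbit moon rocket"]
toy_labels = [0, 0, 1, 1]

toy_X = CountVectorizer().fit_transform(toy_docs)

# loss="hinge" is the default and corresponds to a linear SVM;
# random_state fixes the shuffling so repeated runs give the same model
svm = SGDClassifier(loss="hinge", random_state=42)
svm.fit(toy_X, toy_labels)

toy_pred = svm.predict(toy_X)
print(toy_pred)
```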

In [None]:
docs_new = ["I just scored a home run.", "That's a nice set of wheels."]
X_new_counts = count_vect.transform(docs_new)

predicted = clf.predict(X_new_counts)

for doc, category in zip(docs_new, predicted):
    print("TEXT: %r \t PREDICTION: %s" % (doc, twenty_train.target_names[category]))

TEXT: 'I just scored a home run.' 	 PREDICTION: rec.sport.baseball
TEXT: "That's a nice set of wheels." 	 PREDICTION: misc.forsale


In [None]:
import numpy as np

#Bring in the test set
twenty_test = fetch_20newsgroups(subset='test', shuffle=True, random_state=42)
docs_test = twenty_test.data
X_test = count_vect.transform(twenty_test.data)
predicted = clf.predict(X_test)
np.mean(predicted == twenty_test.target)

0.766728624535316
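
Both attempts follow the same fit, predict, score pattern, so the comparison can also be written as a loop over models. The sketch below does this on hypothetical toy data; the dictionary of models and the variable names are illustrative assumptions rather than part of the original tutorial.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score
from sklearn.naive_bayes import MultinomialNB

# Hypothetical toy train/test data
train_docs = ["fast car engine", "car tires wheels", "rocket orbit launch", "orbit moon rocket"]
train_labels = [0, 0, 1, 1]
test_docs = ["car engine", "moon launch"]
test_labels = [0, 1]

# Fit the vocabulary on the training text only, then reuse it for the test text
vect = CountVectorizer()
X_fit = vect.fit_transform(train_docs)
X_eval = vect.transform(test_docs)

models = {
    "MultinomialNB": MultinomialNB(),
    "SGDClassifier": SGDClassifier(random_state=42),
}

# Train each model on the same features and record its test accuracy
scores = {}
for name, model in models.items():
    model.fit(X_fit, train_labels)
    scores[name] = accuracy_score(test_labels, model.predict(X_eval))
    print(f"{name}: {scores[name]:.3f}")
```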

In [None]:
from sklearn import metrics
print(metrics.classification_report(twenty_test.target, predicted,
                                    target_names=twenty_test.target_names))

metrics.confusion_matrix(twenty_test.target, predicted)

                          precision    recall  f1-score   support

             alt.atheism       0.73      0.72      0.73       319
           comp.graphics       0.59      0.73      0.65       389
 comp.os.ms-windows.misc       0.66      0.55      0.60       394
comp.sys.ibm.pc.hardware       0.61      0.64      0.63       392
   comp.sys.mac.hardware       0.68      0.79      0.73       385
          comp.windows.x       0.81      0.68      0.74       395
            misc.forsale       0.79      0.80      0.80       390
               rec.autos       0.77      0.86      0.81       396
         rec.motorcycles       0.93      0.91      0.92       398
      rec.sport.baseball       0.87      0.86      0.86       397
        rec.sport.hockey       0.94      0.90      0.92       399
               sci.crypt       0.82      0.90      0.86       396
         sci.electronics       0.71      0.66      0.68       393
                 sci.med       0.81      0.73      0.77       396
         

array([[230,   7,   0,   3,   3,   1,   1,   0,   1,   0,   1,   5,   3,
          4,   5,  19,   7,   4,   2,  23],
       [  5, 283,  13,   6,  15,  20,   3,   3,   3,   4,   0,   7,  12,
          4,   0,   2,   3,   2,   1,   3],
       [  0,  35, 218,  56,  22,  18,   2,   7,   1,   3,   0,   8,   6,
          3,   1,   1,   6,   1,   3,   3],
       [  0,  17,  22, 252,  35,   3,  15,   4,   1,   1,   0,   8,  25,
          3,   0,   0,   5,   0,   1,   0],
       [  1,  11,  12,  18, 306,   0,   6,   4,   0,   6,   0,   0,  13,
          3,   0,   0,   5,   0,   0,   0],
       [  1,  44,  30,   8,   5, 270,   4,   4,   2,   2,   2,   3,   6,
          4,   1,   3,   2,   1,   2,   1],
       [  1,   4,   5,   9,  12,   0, 313,  13,   2,   1,   4,   2,  12,
          5,   2,   2,   3,   0,   0,   0],
       [  1,   4,   4,   4,   3,   1,  12, 341,   7,   2,   0,   1,   9,
          2,   1,   0,   2,   0,   2,   0],
       [  1,   3,   0,   1,   1,   0,   4,  12, 361,   2,   2,  