Simple baseline to compare spaCy document classification models against

In [19]:
from __future__ import division 
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.feature_selection import chi2, SelectKBest
from sklearn.model_selection import GridSearchCV, KFold, cross_val_score, train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import binarize 
from sklearn import metrics 
from itertools import islice
from pathlib import Path 
import pandas as pd 
import re
import string

In [8]:
DATA_DIR = Path('../../data/wiki10')
TEXT_DIR = DATA_DIR / 'text' 
LABELS_PATH = DATA_DIR / 'clf0-singlelabel.csv'

Load data

In [22]:
# LABELS_PATH is defined in the cell above; the CSV provides `id` and `tag` columns
df = pd.read_csv(LABELS_PATH)
y = df.tag
texts = [TEXT_DIR.joinpath(doc_id).read_text() for doc_id in df.id]

In [14]:
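def tokenizer(text):
    # Lower-case, pad each punctuation character with spaces, then split on whitespace.
    # re.escape keeps characters such as `]`, `\` and `^` literal inside the character class.
    return re.sub(f'([{re.escape(string.punctuation)}])', r' \1 ', text.lower()).split()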
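
To make the tokenizer's behaviour concrete, here is a small sketch that applies it to a made-up sentence. It is a standalone copy of the function, with `re.escape` applied to the punctuation set (an assumption, so that characters such as the backslash are treated literally); the sample string is purely illustrative.

```python
import re
import string


def tokenizer(text):
    # Pad every punctuation character with spaces, then split on whitespace,
    # so that punctuation marks become tokens of their own.
    return re.sub(f'([{re.escape(string.punctuation)}])', r' \1 ', text.lower()).split()


sample = "Hello, world! It's 3.5% off."  # made-up example text
tokens = tokenizer(sample)
print(tokens)
```

Note that this splits contractions and decimal numbers into several tokens, which is acceptable for a bag-of-words baseline but worth keeping in mind when reading feature names.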

Create the sparse feature matrix and split it into train/test sets (the input text is split as well, for easy reference and model debugging)

In [72]:
cvec = CountVectorizer(tokenizer=tokenizer, min_df=3, stop_words='english')
X = cvec.fit_transform(texts)

X_train, X_test, y_train, y_test, text_train, text_test = train_test_split(X, y, texts, test_size = 0.3, random_state = 0)
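
Note that the vectorizer here is fitted on all documents before the split, so the vocabulary and the `min_df` counts also see the test documents. For a baseline this is probably harmless, but one alternative is to split the raw texts first and make the vectorizer a pipeline step. A minimal sketch on a toy corpus follows; the texts, labels, step names and parameter values are all made up for illustration.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_selection import SelectKBest, chi2
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

# Toy corpus (hypothetical), four documents per class.
toy_texts = [
    'bread butter cheese recipe', 'soup recipe with bread',
    'cheese and butter sauce', 'recipe for cheese soup',
    'theorem proof algebra', 'algebra and geometry proof',
    'geometry theorem lemma', 'lemma proof algebra',
]
toy_labels = ['food'] * 4 + ['math'] * 4

# Split the raw texts, so the vectorizer only ever sees training documents.
text_tr, text_te, y_tr, y_te = train_test_split(
    toy_texts, toy_labels, test_size=0.25, random_state=0, stratify=toy_labels)

text_pipeline = Pipeline([
    ('counts', CountVectorizer()),           # vocabulary learned from training texts only
    ('kbest_feat', SelectKBest(chi2, k=5)),  # k is tiny here because the corpus is tiny
    ('classifier', MultinomialNB()),
])
text_pipeline.fit(text_tr, y_tr)
toy_pred = text_pipeline.predict(text_te)
print(list(zip(y_te, toy_pred)))
```

With this layout, cross-validation on the raw texts also refits the vocabulary inside each fold.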

Define simple training pipeline, using chi-squared feature selection

In [65]:
pipeline = Pipeline([
    ('kbest_feat', SelectKBest(chi2, k=2500)), 
    ('classifier', MultinomialNB()) 
    ])

pipeline.fit(X_train, y_train)

Pipeline(memory=None,
     steps=[('kbest_feat', SelectKBest(k=2500, score_func=<function chi2 at 0x7f2fdd580d90>)), ('classifier', MultinomialNB(alpha=1.0, class_prior=None, fit_prior=True))])

In [66]:
labels = pipeline.classes_

In [67]:
pred_y_test = pipeline.predict(X_test) 

print(metrics.classification_report(y_test, pred_y_test))

             precision    recall  f1-score   support

       food       0.98      0.98      0.98        97
       math       0.99      0.95      0.97       100
      music       0.96      0.94      0.95        83
   politics       0.91      0.95      0.93        93
   religion       0.92      0.95      0.94        85
   software       0.93      0.91      0.92        82

avg / total       0.95      0.95      0.95       540



In [68]:
def confusion_df(y_true, y_pred, labels):
    confusion_df = pd.DataFrame(metrics.confusion_matrix(y_true, y_pred, labels=labels), columns=labels, index=labels)
    confusion_df.columns.name = 'predicted'
    confusion_df.index.name = 'actual'
    return confusion_df

In [69]:
confusion_df(y_test, pred_y_test, labels)

actual \ predicted,food,math,music,politics,religion,software
food,95,0,0,2,0,0
math,0,95,1,2,1,1
music,0,0,78,0,0,5
politics,0,0,0,88,5,0
religion,1,0,1,2,81,0
software,1,1,1,3,1,75
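
As a sanity check, the per-class recall in the classification report can be recomputed from this confusion matrix: each diagonal entry divided by its row total. The sketch below re-enters the counts from the table by hand rather than reusing the notebook variables.

```python
import numpy as np
import pandas as pd

labels = ['food', 'math', 'music', 'politics', 'religion', 'software']

# Counts copied from the confusion matrix above (rows: actual, columns: predicted).
counts = np.array([
    [95, 0, 0, 2, 0, 0],
    [0, 95, 1, 2, 1, 1],
    [0, 0, 78, 0, 0, 5],
    [0, 0, 0, 88, 5, 0],
    [1, 0, 1, 2, 81, 0],
    [1, 1, 1, 3, 1, 75],
])
cm = pd.DataFrame(counts, index=labels, columns=labels)

support = cm.sum(axis=1)                                    # documents per actual class
recall = pd.Series(np.diag(cm.values), index=labels) / support
print(recall.round(2))
```

The row totals should match the `support` column of the report, and the rounded ratios should match its `recall` column.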


Cross-validate the pipeline to check how much the score varies across training folds

In [70]:
cross_val_score(pipeline, X, y, cv=5, scoring='f1_weighted')

array([ 0.94177035,  0.9501224 ,  0.93900446,  0.96123971,  0.96394791])
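
The fold scores are easier to compare against later models when reduced to a mean and a spread. A short sketch, with the values copied from the output above:

```python
import numpy as np

# Weighted F1 per fold, copied from the cross_val_score output above.
scores = np.array([0.94177035, 0.9501224, 0.93900446, 0.96123971, 0.96394791])

mean_f1 = scores.mean()
std_f1 = scores.std()  # population standard deviation across the 5 folds
print(f'weighted F1: {mean_f1:.3f} +/- {std_f1:.3f}')
```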

Investigating a few incorrect predictions

In [74]:
df_test = pd.DataFrame({'text': text_test, 'y_true': y_test, 'y_pred': pred_y_test})

df_test.head()

index,text,y_pred,y_true
597,"In trigonometry and geometry, triangulation is...",math,math
831,"John Leslie ""Wes"" Montgomery (6 March 1923 - 1...",music,music
1174,The Industrial Workers of the World (IWW or th...,politics,politics
467,In mathematics a graph is an abstract represen...,math,math
1722,ArcGIS is a suite consisting of a group of geo...,software,software


In [79]:
software_fn = df_test.loc[(df_test.y_pred != df_test.y_true) & (df_test.y_true == 'software')]

software_fn

index,text,y_pred,y_true
1558,Google Street View is a feature of Google Maps...,politics,software
1637,The following is a list of notable feed aggreg...,food,software
1565,"In aviation, V-speeds or Velocity-speeds are s...",politics,software
1507,BitTorrent may refer to:,religion,software
1676,"In the theory of computation, a nondeterminist...",math,software
1623,A photo booth is a vending machine or modern k...,music,software
1616,"The Indian Institutes of Technology (IITs), ar...",politics,software


In [83]:
for _, row in software_fn.iterrows():
    print(row['y_pred'])
    print(row['text'][:1000], end='\n\n')

politics
Google Street View is a feature of Google Maps and Google Earth that provides for many streets in the world 360° horizontal and 290° vertical panoramic views from a row of positions along the street (one in every 10 or 20 metres, or so), from a height of about two metres. It was launched on May 25, 2007, and is gradually expanded to include more cities, and in these cities more streets, and also some rural areas. These photographs are currently available for countries including the United States, the United Kingdom, the Netherlands, France, Italy, Spain, Australia, New Zealand and Japan. Coverage is shown by dragging "pegman" from its position, on a map of any scale. Google Street View displays photos taken from a fleet of Chevrolet Cobalts in United States, Opel Astras in Europe and Australia, Vauxhall Astras in the United Kingdom and Toyota Prius cars in Japan. Pedestrian areas, narrow streets and park alleys that cannot be accessed by car are not always covered. However, so

Most of these misclassifications are understandable. It seems that geography-related documents are being associated with politics, so software articles that contain a lot of place names are predicted as politics, which counts as a false negative for software.

Also note that some of the articles are very short. This is likely because the HTML parsing only picks up `<p>` tags, while useful text also exists in other tags. For instance, both truncated examples here appear to be lists, so their content may sit in `<ul>` or `<ol>` tags. This may also go some way towards explaining the doc2vec clusters for "list" articles noticed in 03_doc2vec_evaluate.
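
The extraction code is not part of this notebook, so the following is only a sketch of the idea: a parser that collects text from `<li>` elements as well as `<p>` elements, using the standard-library `html.parser`. The class name `TextBlockParser` and the sample HTML are hypothetical.

```python
from html.parser import HTMLParser


class TextBlockParser(HTMLParser):
    """Collect the text inside <p> and <li> elements (hypothetical helper)."""

    KEEP = {'p', 'li'}

    def __init__(self):
        super().__init__()
        self.depth = 0     # number of kept elements currently open
        self.blocks = []   # finished text blocks, in document order
        self._buffer = []  # text fragments of the block being read

    def handle_starttag(self, tag, attrs):
        if tag in self.KEEP:
            self.depth += 1

    def handle_endtag(self, tag):
        if tag in self.KEEP and self.depth > 0:
            self.depth -= 1
            if self.depth == 0:
                # Normalise whitespace and store the block if it is not empty.
                text = ' '.join(''.join(self._buffer).split())
                if text:
                    self.blocks.append(text)
                self._buffer = []

    def handle_data(self, data):
        if self.depth > 0:
            self._buffer.append(data)


# Made-up page fragment: a lead paragraph followed by a list.
sample_html = ('<p>BitTorrent may refer to:</p>'
               '<ul><li>first item</li><li>second item</li></ul>')
parser = TextBlockParser()
parser.feed(sample_html)
parser.close()
print(parser.blocks)
```

A parser restricted to `<p>` would keep only the first block here, which is consistent with the one-line articles seen above.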

For now this result is fine as a baseline, and we will use it to validate our spaCy models before training on more complex inputs.