In [1]:
%config InlineBackend.figure_format = 'retina'

In [2]:
%load_ext autoreload

%autoreload 1

In [3]:
import numpy as np

import pickle

from pathlib import Path

# Load data

In [4]:
data_root = Path.home() / "data" / "tmp"
reuters_corpus_path = data_root / "reuters21578" / "corpus.pkl"
reuters = pickle.load(open(reuters_corpus_path, "rb"))
top_ten_ids, top_ten_names = reuters.top_n(n=10)

# Simple linear model

### Get text and labels

In [5]:
train_docs, test_docs = reuters.split_modapte()
print(len(train_docs), len(test_docs))

train = [d["text"] for d in train_docs]
train_labels = reuters.get_labels(train_docs, set(top_ten_ids))
y_train = np.array(train_labels)

test = [d["text"] for d in test_docs]
test_labels = reuters.get_labels(test_docs, set(top_ten_ids))
y_test = np.array(test_labels)

7770 3019


### Vectorize texts

In [6]:
from sklearn.feature_extraction.text import TfidfVectorizer

In [7]:
vectorizer = TfidfVectorizer()
vectorizer.fit(train)

TfidfVectorizer(analyzer='word', binary=False, decode_error='strict',
        dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',
        lowercase=True, max_df=1.0, max_features=None, min_df=1,
        ngram_range=(1, 1), norm='l2', preprocessor=None, smooth_idf=True,
        stop_words=None, strip_accents=None, sublinear_tf=False,
        token_pattern='(?u)\\b\\w\\w+\\b', tokenizer=None, use_idf=True,
        vocabulary=None)

In [8]:
X_train = vectorizer.transform(train)
X_test = vectorizer.transform(test)

### Test models

In [9]:
from sklearn.metrics import classification_report

#### Logistic regression

In [10]:
from sklearn.linear_model import LogisticRegression
#model = LogisticRegression(C=100.)
model = LogisticRegression()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

In [11]:
print(classification_report(y_test, y_pred, target_names=top_ten_names, labels=top_ten_ids))

             precision    recall  f1-score   support

       earn       0.94      0.98      0.96      1087
        acq       0.77      0.98      0.87       710
   money-fx       0.70      0.77      0.73       145
      grain       0.68      0.36      0.47        42
      crude       0.76      0.84      0.79       164
      trade       0.63      0.89      0.74       109
   interest       0.78      0.72      0.75       117
       ship       0.71      0.56      0.63        71
      wheat       0.77      0.62      0.69        55
       corn       0.63      0.69      0.66        45

avg / total       0.83      0.91      0.86      2545



#### Linear Support Vector Machine

In [14]:
from sklearn.svm import LinearSVC
model = LinearSVC()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

In [15]:
print(classification_report(y_test, y_pred, target_names=top_ten_names, labels=top_ten_ids))

             precision    recall  f1-score   support

       earn       0.99      0.98      0.98      1087
        acq       0.93      0.98      0.95       710
   money-fx       0.73      0.81      0.77       145
      grain       0.66      0.55      0.60        42
      crude       0.80      0.89      0.84       164
      trade       0.73      0.83      0.78       109
   interest       0.81      0.75      0.78       117
       ship       0.74      0.63      0.68        71
      wheat       0.75      0.71      0.73        55
       corn       0.64      0.71      0.67        45

avg / total       0.90      0.92      0.91      2545



#### RBF SVM

In [18]:
from sklearn.svm import SVC
model = SVC()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

In [19]:
print(classification_report(y_test, y_pred, target_names=top_ten_names, labels=top_ten_ids))

             precision    recall  f1-score   support

       earn       0.36      1.00      0.53      1087
        acq       0.00      0.00      0.00       710
   money-fx       0.00      0.00      0.00       145
      grain       0.00      0.00      0.00        42
      crude       0.00      0.00      0.00       164
      trade       0.00      0.00      0.00       109
   interest       0.00      0.00      0.00       117
       ship       0.00      0.00      0.00        71
      wheat       0.00      0.00      0.00        55
       corn       0.00      0.00      0.00        45

avg / total       0.15      0.43      0.23      2545



  'precision', 'predicted', average, warn_for)


#### Naive Bayes

In [16]:
from sklearn.naive_bayes import MultinomialNB

model = MultinomialNB()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

In [17]:
print(classification_report(y_test, y_pred, target_names=top_ten_names, labels=top_ten_ids))

             precision    recall  f1-score   support

       earn       0.68      0.99      0.80      1087
        acq       0.56      0.94      0.70       710
   money-fx       0.66      0.54      0.60       145
      grain       0.00      0.00      0.00        42
      crude       0.93      0.26      0.40       164
      trade       0.91      0.28      0.43       109
   interest       0.81      0.18      0.29       117
       ship       1.00      0.01      0.03        71
      wheat       0.00      0.00      0.00        55
       corn       0.00      0.00      0.00        45

avg / total       0.65      0.75      0.63      2545



  'precision', 'predicted', average, warn_for)


In [113]:
from sklearn.neighbors import KNeighborsClassifier

model = KNeighborsClassifier(n_neighbors=10)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

In [114]:
print(classification_report(y_test, y_pred, target_names=top_ten_names, labels=top_ten_ids))

             precision    recall  f1-score   support

       earn       0.43      0.82      0.56      1087
        acq       0.93      0.05      0.10       710
   money-fx       0.19      0.10      0.13       145
      grain       0.50      0.02      0.05        42
      crude       0.46      0.26      0.33       164
      trade       0.57      0.07      0.13       109
   interest       0.52      0.31      0.39       117
       ship       0.04      0.03      0.03        71
      wheat       0.50      0.11      0.18        55
       corn       0.64      0.16      0.25        45

avg / total       0.56      0.41      0.33      2545

