In [1]:
%config InlineBackend.figure_format = 'retina'

In [2]:
%load_ext autoreload

%autoreload 1

In [3]:
import numpy as np

import pickle

from pathlib import Path

# Load data

In [4]:
data_root = Path.home() / "data" / "tmp"
reuters_corpus_path = data_root / "reuters21578" / "corpus.pkl"
reuters = pickle.load(open(reuters_corpus_path, "rb"))
top_ten_ids, top_ten_names = reuters.top_n(n=10)

# Simple linear model

### Get text and labels

In [5]:
train_docs, test_docs = reuters.split_modapte()
print(len(train_docs), len(test_docs))

train = [d["text"] for d in train_docs]
train_labels = reuters.get_labels(train_docs, set(top_ten_ids))
y_train = np.array(train_labels)

test = [d["text"] for d in test_docs]
test_labels = reuters.get_labels(test_docs, set(top_ten_ids))
y_test = np.array(test_labels)

7770 3019


### Vectorize texts

In [6]:
from sklearn.feature_extraction.text import TfidfVectorizer

In [7]:
vectorizer = TfidfVectorizer()
vectorizer.fit(train)

TfidfVectorizer(analyzer='word', binary=False, decode_error='strict',
        dtype=<class 'numpy.float64'>, encoding='utf-8', input='content',
        lowercase=True, max_df=1.0, max_features=None, min_df=1,
        ngram_range=(1, 1), norm='l2', preprocessor=None, smooth_idf=True,
        stop_words=None, strip_accents=None, sublinear_tf=False,
        token_pattern='(?u)\\b\\w\\w+\\b', tokenizer=None, use_idf=True,
        vocabulary=None)

In [8]:
X_train = vectorizer.transform(train)
X_test = vectorizer.transform(test)

### Test models

In [9]:
from sklearn.metrics import classification_report

#### Logistic regression

In [21]:
from sklearn.linear_model import LogisticRegression
#model = LogisticRegression(C=100., solver="liblinear", multi_class="ovr")
model = LogisticRegression(solver="liblinear", multi_class="ovr")
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

In [22]:
print(classification_report(y_test, y_pred, target_names=top_ten_names, labels=top_ten_ids))

              precision    recall  f1-score   support

        earn       0.98      0.99      0.98      1087
         acq       0.94      0.98      0.96       710
    money-fx       0.69      0.82      0.75       145
       grain       0.62      0.55      0.58        42
       crude       0.80      0.88      0.84       164
       trade       0.68      0.85      0.76       109
    interest       0.80      0.72      0.76       117
        ship       0.66      0.63      0.65        71
       wheat       0.75      0.73      0.74        55
        corn       0.66      0.69      0.67        45

   micro avg       0.89      0.92      0.91      2545
   macro avg       0.76      0.78      0.77      2545
weighted avg       0.89      0.92      0.91      2545



#### Linear Support Vector Machine

In [12]:
from sklearn.svm import LinearSVC
model = LinearSVC()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

In [13]:
print(classification_report(y_test, y_pred, target_names=top_ten_names, labels=top_ten_ids, digits=3))

              precision    recall  f1-score   support

        earn      0.989     0.981     0.985      1087
         acq      0.926     0.982     0.953       710
    money-fx      0.731     0.807     0.767       145
       grain      0.657     0.548     0.597        42
       crude      0.802     0.890     0.844       164
       trade      0.734     0.835     0.781       109
    interest      0.815     0.752     0.782       117
        ship      0.738     0.634     0.682        71
       wheat      0.750     0.709     0.729        55
        corn      0.640     0.711     0.674        45

   micro avg      0.900     0.921     0.911      2545
   macro avg      0.778     0.785     0.779      2545
weighted avg      0.902     0.921     0.910      2545



#### Naive Bayes

In [14]:
from sklearn.naive_bayes import MultinomialNB

model = MultinomialNB()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

In [15]:
print(classification_report(y_test, y_pred, target_names=top_ten_names, labels=top_ten_ids))

              precision    recall  f1-score   support

        earn       0.68      0.99      0.80      1087
         acq       0.56      0.94      0.70       710
    money-fx       0.66      0.54      0.60       145
       grain       0.00      0.00      0.00        42
       crude       0.93      0.26      0.40       164
       trade       0.91      0.28      0.43       109
    interest       0.81      0.18      0.29       117
        ship       1.00      0.01      0.03        71
       wheat       0.00      0.00      0.00        55
        corn       0.00      0.00      0.00        45

   micro avg       0.63      0.75      0.69      2545
   macro avg       0.55      0.32      0.33      2545
weighted avg       0.65      0.75      0.63      2545



  'precision', 'predicted', average, warn_for)
