In [None]:
%config InlineBackend.figure_format = 'retina'

In [None]:
%load_ext autoreload

%autoreload 1

In [None]:
import numpy as np

import pickle

from pathlib import Path

# Load data

In [None]:
from ds_tutorial.datasets import ReutersCorpus

In [None]:
# Locations of the pickled Reuters artifacts under ~/data/tmp.
data_root = Path.home() / "data" / "tmp"
reuters_corpus_path = data_root / "reuters21578" / "corpus.pkl"
reuters_documents_path = data_root / "reuters21578" / "documents.pkl"

# NOTE(security): unpickling can execute arbitrary code; only load files that
# were produced locally by a trusted process.
# The context manager closes the file handle as soon as loading finishes
# instead of leaving it open until garbage collection.
with open(reuters_corpus_path, "rb") as corpus_file:
    reuters = pickle.load(corpus_file)

# Ids and names of the categories selected by top_n(n=10) -- presumably the
# ten most frequent ones (confirm in ReutersCorpus). They restrict the
# classification reports further down.
top_ten_ids, top_ten_names = reuters.top_n(n=10)

In [None]:
# Show the category ids and names returned by reuters.top_n(n=10).
top_ten_ids, top_ten_names

([3, 10, 13, 17, 4, 9, 0, 19, 22, 18],
 ['earn',
  'acq',
  'money-fx',
  'grain',
  'crude',
  'trade',
  'interest',
  'ship',
  'wheat',
  'corn'])

# Simple linear model

### Get text and labels

In [None]:
from sklearn.preprocessing import MultiLabelBinarizer

In [None]:
# Split the corpus into training and test documents via split_modapte().
train_docs, test_docs = reuters.split_modapte()
print(len(train_docs), len(test_docs))

# Raw texts and per-document label collections for each split.
train = [d["text"] for d in train_docs]
train_labels = reuters.get_all_labels(train_docs)

test = [d["text"] for d in test_docs]
test_labels = reuters.get_all_labels(test_docs)

# Fit ONE binarizer on the training labels and reuse it for the test labels.
# Fitting a second binarizer on the test split derives the column order from
# the test label set, so columns would silently shift (or the width would
# differ) whenever the two splits do not contain exactly the same labels.
# transform() ignores, with a warning, labels that were not seen during fit.
label_binarizer = MultiLabelBinarizer()
y_train = label_binarizer.fit_transform(train_labels)
y_test = label_binarizer.transform(test_labels)

7770 3019


In [None]:
from sklearn.preprocessing import MultiLabelBinarizer

# Learn the label -> column mapping from the training labels only and apply
# the same mapping to the test labels. Two independently fitted binarizers
# are only column-aligned when both splits contain the identical label set;
# otherwise predictions would be scored against the wrong columns.
# transform() ignores, with a warning, labels that were not seen during fit.
label_binarizer = MultiLabelBinarizer()
train_labels_binarized = label_binarizer.fit_transform(train_labels)
test_labels_binarized = label_binarizer.transform(test_labels)

In [None]:
# Check the (n_documents, n_labels) shapes of both label indicator matrices.
train_labels_binarized.shape, test_labels_binarized.shape

((7770, 90), (3019, 90))

### Vectorize texts

In [None]:
from sklearn.feature_extraction.text import TfidfVectorizer

In [None]:
# fit() returns the estimator itself, so the TF-IDF vocabulary and weights
# are learned from the training texts in the statement that builds it.
vectorizer = TfidfVectorizer().fit(train)
# Last expression of the cell: displays the fitted estimator's repr.
vectorizer

TfidfVectorizer(analyzer='word', binary=False, decode_error='strict',
        dtype=<class 'numpy.float64'>, encoding='utf-8', input='content',
        lowercase=True, max_df=1.0, max_features=None, min_df=1,
        ngram_range=(1, 1), norm='l2', preprocessor=None, smooth_idf=True,
        stop_words=None, strip_accents=None, sublinear_tf=False,
        token_pattern='(?u)\\b\\w\\w+\\b', tokenizer=None, use_idf=True,
        vocabulary=None)

In [None]:
# Map both text collections through the already fitted vectorizer.
X_train, X_test = (vectorizer.transform(texts) for texts in (train, test))

### Test models

In [None]:
from sklearn.metrics import classification_report
from sklearn.multiclass import OneVsRestClassifier

#### Logistic regression

In [None]:
from sklearn.linear_model import LogisticRegression

# OneVsRestClassifier trains one binary classifier per label column, so
# LogisticRegression's `multi_class` option has no effect here. It is left
# out because recent scikit-learn releases deprecate the parameter.
# Alternative to experiment with: LogisticRegression(C=100, solver="liblinear")
# for weaker regularization (C is the inverse regularization strength).
model = OneVsRestClassifier(LogisticRegression(solver="liblinear"))
model.fit(X_train, train_labels_binarized)

# Binary indicator predictions for the test documents.
y_pred = model.predict(X_test)

In [None]:
# Precision / recall / F1 restricted to the label columns in top_ten_ids.
print(
    classification_report(
        y_true=test_labels_binarized,
        y_pred=y_pred,
        labels=top_ten_ids,
        target_names=top_ten_names,
    )
)

              precision    recall  f1-score   support

        earn       0.99      0.97      0.98      1087
         acq       0.98      0.92      0.95       719
    money-fx       0.78      0.51      0.62       179
       grain       0.99      0.60      0.75       149
       crude       0.96      0.57      0.72       189
       trade       0.93      0.54      0.68       117
    interest       0.91      0.47      0.62       131
        ship       1.00      0.13      0.24        89
       wheat       0.97      0.51      0.67        71
        corn       0.95      0.32      0.48        56

   micro avg       0.97      0.79      0.87      2787
   macro avg       0.95      0.56      0.67      2787
weighted avg       0.97      0.79      0.85      2787
 samples avg       0.70      0.69      0.69      2787



  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)


#### Linear Support Vector Machine

In [None]:
from sklearn.svm import LinearSVC

# One linear SVM per label column; fit() returns the fitted wrapper itself.
model = OneVsRestClassifier(LinearSVC()).fit(X_train, train_labels_binarized)
# Binary indicator predictions for the test documents.
y_pred = model.predict(X_test)

In [None]:
# Same report as for the other models, printed with three decimal places.
print(
    classification_report(
        y_true=test_labels_binarized,
        y_pred=y_pred,
        labels=top_ten_ids,
        target_names=top_ten_names,
        digits=3,
    )
)

              precision    recall  f1-score   support

        earn      0.991     0.980     0.985      1087
         acq      0.984     0.950     0.967       719
    money-fx      0.810     0.788     0.799       179
       grain      0.975     0.799     0.878       149
       crude      0.906     0.868     0.886       189
       trade      0.830     0.709     0.765       117
    interest      0.870     0.664     0.753       131
        ship      0.924     0.685     0.787        89
       wheat      0.929     0.732     0.819        71
        corn      0.955     0.750     0.840        56

   micro avg      0.956     0.896     0.925      2787
   macro avg      0.917     0.793     0.848      2787
weighted avg      0.954     0.896     0.922      2787
 samples avg      0.771     0.769     0.767      2787



  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)


#### Naive Bayes

In [None]:
from sklearn.naive_bayes import MultinomialNB

# One multinomial Naive Bayes classifier per label column; fit() returns the
# fitted wrapper itself.
model = OneVsRestClassifier(MultinomialNB()).fit(X_train, train_labels_binarized)
# Binary indicator predictions for the test documents.
y_pred = model.predict(X_test)

In [None]:
# Precision / recall / F1 restricted to the label columns in top_ten_ids.
print(
    classification_report(
        y_true=test_labels_binarized,
        y_pred=y_pred,
        labels=top_ten_ids,
        target_names=top_ten_names,
    )
)

              precision    recall  f1-score   support

        earn       1.00      0.91      0.95      1087
         acq       1.00      0.29      0.44       719
    money-fx       0.00      0.00      0.00       179
       grain       1.00      0.03      0.05       149
       crude       1.00      0.02      0.03       189
       trade       0.00      0.00      0.00       117
    interest       0.00      0.00      0.00       131
        ship       0.00      0.00      0.00        89
       wheat       0.00      0.00      0.00        71
        corn       0.00      0.00      0.00        56

   micro avg       1.00      0.43      0.60      2787
   macro avg       0.40      0.12      0.15      2787
weighted avg       0.77      0.43      0.49      2787
 samples avg       0.40      0.40      0.40      2787



  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)
