# Logistic Regression Multi-class Classification


This notebook is modified from a tutorial in sklearn.


In [1]:
# Load the training split of the 20 Newsgroups corpus, restricted to four topics.
from sklearn.datasets import fetch_20newsgroups

categories = [
    'alt.atheism',
    'soc.religion.christian',
    'comp.graphics',
    'sci.med',
]
twenty_train = fetch_20newsgroups(
    subset='train',
    categories=categories,
    shuffle=True,
    random_state=42,  # fixed seed so the shuffled order is reproducible
)
# Show the class names of the loaded subset.
twenty_train.target_names

['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']

Since we have a multi-class label, the labels must be encoded as integer classes. The dataset already provides them in this form, as the next cell shows.

In [2]:
# notice that our target is multi-class: each entry is an integer class label
twenty_train.target

array([1, 1, 3, ..., 2, 2, 2])

#### Multinomial Logistic Regression

We need to set a few parameters if we want to give sklearn a heads-up that our target has more than 2 classes, so we set these parameters in the LogisticRegression() call:

- multi_class='multinomial'
- solver = 'lbfgs'

For one-versus-all (one-versus-rest), you can use the liblinear solver, which was the default before scikit-learn 0.22. For multinomial logistic regression, the solver has to be sag, saga, newton-cg or lbfgs. We chose lbfgs, a quasi-Newton optimizer that doesn't use a lot of memory (the "l" stands for limited-memory). The sag and saga solvers are variants of stochastic gradient descent.

In [3]:
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
# Import from the public package: the private ``sklearn.linear_model.logistic``
# module path is deprecated and cannot be imported in newer scikit-learn releases.
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, log_loss

# Text-classification pipeline: TF-IDF features feeding a multinomial
# logistic regression fitted with the lbfgs solver.
# NOTE(review): recent scikit-learn versions reportedly deprecate the
# `multi_class` argument -- confirm against the installed version before
# upgrading; it is kept here so older versions do not fall back to one-vs-rest.
pipe1 = Pipeline([
    ('tfidf', TfidfVectorizer()),
    ('logreg', LogisticRegression(
        multi_class='multinomial',
        solver='lbfgs',
        class_weight='balanced',
    )),
])

# Fit on the raw training documents; the fitted pipeline is the cell's output.
pipe1.fit(twenty_train.data, twenty_train.target)



Pipeline(memory=None,
         steps=[('tfidf',
                 TfidfVectorizer(analyzer='word', binary=False,
                                 decode_error='strict',
                                 dtype=<class 'numpy.float64'>,
                                 encoding='utf-8', input='content',
                                 lowercase=True, max_df=1.0, max_features=None,
                                 min_df=1, ngram_range=(1, 1), norm='l2',
                                 preprocessor=None, smooth_idf=True,
                                 stop_words=None, strip_accents=None,
                                 sublinear_tf=False,
                                 token_pattern='(?u)\\b\\w\\w+\\b',
                                 tokenizer=None, use_idf=True,
                                 vocabulary=None)),
                ('logreg',
                 LogisticRegression(C=1.0, class_weight='balanced', dual=False,
                                    fit_intercept=True, intercep

In [4]:
# Score the fitted pipeline on the held-out test split.
from sklearn import metrics
import numpy as np

twenty_test = fetch_20newsgroups(
    subset='test',
    categories=categories,
    shuffle=True,
    random_state=42,
)
pred = pipe1.predict(twenty_test.data)

# Per-class precision / recall / F1 summary.
report = metrics.classification_report(
    twenty_test.target,
    pred,
    target_names=twenty_test.target_names,
)
print(report)

# Confusion matrix of true test labels against predicted labels.
conf_mat = metrics.confusion_matrix(twenty_test.target, pred)
print("Confusion matrix:\n", conf_mat)

# Fraction of test documents whose predicted label equals the true label.
is_correct = pred == twenty_test.target
print("\nOverall accuracy: ", np.mean(is_correct))

                        precision    recall  f1-score   support

           alt.atheism       0.95      0.81      0.87       319
         comp.graphics       0.85      0.96      0.90       389
               sci.med       0.93      0.88      0.90       396
soc.religion.christian       0.90      0.94      0.92       398

              accuracy                           0.90      1502
             macro avg       0.91      0.90      0.90      1502
          weighted avg       0.91      0.90      0.90      1502

Confusion matrix:
 [[258  13  14  34]
 [  3 374   6   6]
 [  5  41 347   3]
 [  6  11   5 376]]

Overall accuracy:  0.9021304926764314


In [5]:
# Predicted class probabilities for every test document (one row per document,
# one column per class); preview only the first five rows.
probs = pipe1.predict_proba(twenty_test.data)
probs[:5]

array([[0.14242452, 0.1643475 , 0.5691282 , 0.12409978],
       [0.0410866 , 0.03622316, 0.88370887, 0.03898138],
       [0.28678063, 0.10662445, 0.39017533, 0.21641959],
       [0.91486571, 0.02017099, 0.02211033, 0.04285297],
       [0.13633055, 0.06384863, 0.10635021, 0.6934706 ]])