Multiclass Classifiers

Load a dataset, train two models for multiclass classification, and compare their results. The dataset is the digits dataset from scikit-learn's datasets module (sklearn.datasets). It contains 1797 samples of handwritten digits, each an 8x8 grayscale image flattened into 64 features. The goal is to correctly identify each digit from 0 to 9.

In [1]:
# Load the data

from sklearn.datasets import load_digits
X, y = load_digits(return_X_y=True)



In [2]:
# Examine the data 


import numpy as np

print('The number of rows in the dataset is {:d}'.format(X.shape[0]))
print('The number of features in the dataset is {:d}'.format(X.shape[1]))
# Count the samples in each class (digits 0 to 9)
np.bincount(y)

The number of rows in the dataset is 1797
The number of features in the dataset is 64


array([178, 182, 177, 183, 181, 182, 181, 179, 174, 180], dtype=int64)
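
The 64 features are the pixel intensities of an 8x8 image, so any row can be reshaped back into its grid for inspection. A minimal sketch, assuming the same `load_digits` call as above (the name `first_image` is illustrative and not part of the original notebook):

```python
from sklearn.datasets import load_digits

X, y = load_digits(return_X_y=True)

# Each row holds 64 pixel intensities; reshape the first row to its 8x8 grid.
first_image = X[0].reshape(8, 8)

print('Label of the first sample:', y[0])
print('Image shape:', first_image.shape)
print('Pixel value range: {:.0f} to {:.0f}'.format(X.min(), X.max()))
```

The reshaped array can be passed to a plotting function such as matplotlib's `imshow` to view the digit.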

In [3]:
# Prepare training and testing data

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=40)
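
The split above draws samples at random, so class proportions in the test set can drift slightly from those in the full dataset. One possible variation, shown only as a sketch, passes `stratify=y` so that each digit keeps roughly the same share in both parts. The `stratify` argument is an addition that the original notebook does not use, and it changes which samples land in each split.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

X, y = load_digits(return_X_y=True)

# Assumption: stratify=y is added here for illustration; the original split omits it.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=40, stratify=y)

print('Training rows:', X_train.shape[0])
print('Test rows:', X_test.shape[0])
print('Test samples per class:', np.bincount(y_test))
```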

In [4]:
# Cross-validation with logistic regression (one-vs-rest)

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

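# Note: recent scikit-learn releases deprecate the multi_class argument;
# wrapping the model in OneVsRestClassifier is the suggested replacement.
lr_clf = LogisticRegression(solver='lbfgs', multi_class='ovr', max_iter=1000)
# 5-fold cross-validation on the training data only
lr_cv_scores = cross_val_score(lr_clf, X_train, y_train, cv=5)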

print('Accuracy scores for the 5 folds: ', lr_cv_scores)
print('Mean cross-validation score: {:.3f}'.format(np.mean(lr_cv_scores)))

Accuracy scores for the 5 folds:  [0.94791667 0.94791667 0.95470383 0.94425087 0.95470383]
Mean cross-validation score: 0.950
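
The mean alone hides how much the folds disagree, which matters when two models score closely. A minimal sketch that also reports the spread, using the five fold scores copied from the output above (the name `reported_scores` is illustrative):

```python
import numpy as np

# Fold accuracies copied from the logistic regression output above.
reported_scores = np.array([0.94791667, 0.94791667, 0.95470383, 0.94425087, 0.95470383])

mean_score = reported_scores.mean()
std_score = reported_scores.std()  # population standard deviation across the folds

print('Mean accuracy: {:.3f}'.format(mean_score))
print('Standard deviation: {:.3f}'.format(std_score))
```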


In [5]:
# Cross validation with RandomForest

from sklearn.ensemble import RandomForestClassifier

# No random_state is set, so the scores vary slightly from run to run
rf_clf = RandomForestClassifier(n_estimators=24)
rf_cv_scores = cross_val_score(rf_clf, X_train, y_train, cv=5)

print('Accuracy scores for the 5 folds: ', rf_cv_scores)
print('Mean cross-validation score: {:.3f}'.format(np.mean(rf_cv_scores)))

Accuracy scores for the 5 folds:  [0.96875    0.95486111 0.95818815 0.96864111 0.95470383]
Mean cross-validation score: 0.961
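
The cross-validation scores use only the training data, and the held-out test split is not evaluated in the cells above. A possible final comparison, shown as a sketch and not as the notebook's own method, fits both models on the training set and scores them on the test set. Two choices here are assumptions: `OneVsRestClassifier` stands in for `multi_class='ovr'`, and `random_state=0` is added to the random forest so the sketch is repeatable. The `test_accuracy` dictionary is an illustrative name.

```python
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier

X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=40)

# Assumption: OneVsRestClassifier replaces multi_class='ovr' from the original cell.
lr_clf = OneVsRestClassifier(LogisticRegression(solver='lbfgs', max_iter=1000))
# Assumption: random_state=0 is added only to make this sketch repeatable.
rf_clf = RandomForestClassifier(n_estimators=24, random_state=0)

# Fit each model on the training split and score it on the unseen test split.
test_accuracy = {}
for name, clf in [('Logistic regression', lr_clf), ('Random forest', rf_clf)]:
    clf.fit(X_train, y_train)
    test_accuracy[name] = accuracy_score(y_test, clf.predict(X_test))
    print('{} test accuracy: {:.3f}'.format(name, test_accuracy[name]))
```

Test accuracy computed this way should be reported once, after model selection, so that the test split stays an unbiased estimate.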
