# Chapter 3: Classification

**Reference:** Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow (Aurélien Géron)

---

## 1. Chapter Introduction

In Chapter 1, we mentioned that the most common supervised learning tasks are regression (predicting values) and classification (predicting classes). In Chapter 2, we explored a regression task, predicting housing values, using various algorithms such as Linear Regression, Decision Trees, and Random Forests. Now we will turn our attention to classification systems.

**MNIST**

In this chapter, we will be using the MNIST dataset, which is a set of 70,000 small images of digits handwritten by high school students and employees of the US Census Bureau. Each image is labeled with the digit it represents. This set has been studied so much that it is often called the "hello world" of Machine Learning: whenever people come up with a new classification algorithm, they are curious to see how it will perform on MNIST, and anyone who learns Machine Learning tackles this dataset sooner or later.

We will start by looking at binary classification (distinguishing between two classes), then move on to multiclass classification (more than two classes). We will also discuss multilabel classification (assigning multiple labels to each instance) and multioutput classification. Along the way, we will learn about various performance measures, which are essential for evaluating classifiers.

## 2. Theoretical Explanations

### A. Binary Classification
A binary classifier is capable of distinguishing between just two classes. For example, to build a system that detects the number 5, we would create a "5-detector" that distinguishes between two classes: "5" and "not-5".

### B. Performance Measures

Evaluating a classifier is often significantly trickier than evaluating a regressor.

**1. Measuring Accuracy Using Cross-Validation**
A good way to evaluate a model is to use cross-validation. However, accuracy is generally not the preferred performance measure for classifiers, especially when you are dealing with **skewed datasets** (i.e., when some classes are much more frequent than others). For example, if only 10% of the images are 5s, a classifier that always guesses "not-5" will have 90% accuracy! This demonstrates why accuracy alone can be misleading.

**2. Confusion Matrix**
A much better way to evaluate the performance of a classifier is to look at the confusion matrix. The general idea is to count the number of times instances of class A are classified as class B. For example, to know how many times the classifier confused images of 5s with 3s, you would look at row #5, column #3 of the confusion matrix (rows and columns are indexed by digit, starting from 0).

A confusion matrix has the following structure for a binary classifier:

| | **Predicted Negative** | **Predicted Positive** |
|---|---|---|
| **Actual Negative** | True Negative (TN) | False Positive (FP) |
| **Actual Positive** | False Negative (FN) | True Positive (TP) |

* **True Negatives (TN):** Correctly classified negative instances.
* **False Positives (FP):** Negative instances incorrectly classified as positive (Type I error).
* **False Negatives (FN):** Positive instances incorrectly classified as negative (Type II error).
* **True Positives (TP):** Correctly classified positive instances.
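
To make the four cells concrete, here is a small sketch that counts them by hand in plain Python. The labels and predictions are made up for illustration (they are not MNIST results), and `True` stands for the positive class.

```python
# Toy labels, invented for illustration: True = positive class ("5").
y_true = [False, False, False, False, False, False, True, True, True, True]
y_pred = [False, False, False, False, True, True, True, True, True, False]

# Count each cell of the 2x2 confusion matrix by comparing pairs.
tn = sum((not t) and (not p) for t, p in zip(y_true, y_pred))
fp = sum((not t) and p for t, p in zip(y_true, y_pred))
fn = sum(t and (not p) for t, p in zip(y_true, y_pred))
tp = sum(t and p for t, p in zip(y_true, y_pred))

# Same layout as the table: rows are actual classes, columns are predictions.
matrix = [[tn, fp],
          [fn, tp]]
print(matrix)  # [[4, 2], [1, 3]]
```

The four counts always add up to the number of instances, which is a quick sanity check on any confusion matrix.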

**3. Precision and Recall**
The confusion matrix gives you a lot of information, but sometimes you may prefer a more concise metric. 

**Precision** is the accuracy of the positive predictions:
$$ Precision = \frac{TP}{TP + FP} $$

**Recall** (also called sensitivity or true positive rate) is the ratio of positive instances that are correctly detected by the classifier:
$$ Recall = \frac{TP}{TP + FN} $$

**4. F1 Score**
It is often convenient to combine precision and recall into a single metric called the $F_1$ score, in particular if you need a simple way to compare two classifiers. The $F_1$ score is the harmonic mean of precision and recall. Whereas the regular mean treats all values equally, the harmonic mean gives much more weight to low values. As a result, the classifier will only get a high $F_1$ score if both recall and precision are high.

$$ F_1 = \frac{2}{\frac{1}{precision} + \frac{1}{recall}} = 2 \times \frac{precision \times recall}{precision + recall} = \frac{TP}{TP + \frac{FN + FP}{2}} $$
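
As a quick check of the formulas above, the sketch below plugs in hypothetical counts (TP = 3, FP = 2, FN = 1, chosen only to keep the arithmetic simple) and confirms that the three forms of the $F_1$ score agree.

```python
import math

# Hypothetical counts, chosen only to make the arithmetic easy to follow.
tp, fp, fn = 3, 2, 1

precision = tp / (tp + fp)   # 3 / 5 = 0.6
recall = tp / (tp + fn)      # 3 / 4 = 0.75

# The three equivalent forms of the F1 score.
f1_harmonic = 2 / (1 / precision + 1 / recall)
f1_product = 2 * precision * recall / (precision + recall)
f1_counts = tp / (tp + (fn + fp) / 2)

# The regular (arithmetic) mean, for comparison with the harmonic mean.
arithmetic_mean = (precision + recall) / 2

print(round(f1_harmonic, 4), round(arithmetic_mean, 4))  # 0.6667 0.675
```

The harmonic mean lands below the arithmetic mean because it is pulled toward the lower of the two values, here the precision.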

**5. The Precision/Recall Trade-off**
Ideally, you would want both high precision and high recall, but in practice increasing one tends to reduce the other. This is called the *precision/recall trade-off*. To understand this trade-off, let's consider how the classifier makes its decisions. For each instance, it computes a score based on a *decision function*; if that score is greater than a threshold, it assigns the instance to the positive class, or else it assigns it to the negative class. Raising the threshold generally increases precision (although precision can occasionally dip) and can only lower recall or leave it unchanged. Lowering the threshold does the opposite: recall goes up, usually at the cost of precision.

**6. The ROC Curve**
The *Receiver Operating Characteristic* (ROC) curve is another common tool used with binary classifiers. It is very similar to the precision/recall curve, but instead of plotting precision versus recall, the ROC curve plots the *true positive rate* (another name for recall) against the *false positive rate* (FPR). The FPR is the ratio of negative instances that are incorrectly classified as positive. It is equal to one minus the *true negative rate*, which is the ratio of negative instances that are correctly classified as negative. The TNR is also called *specificity*. Hence the ROC curve plots sensitivity (recall) versus 1 – specificity.

$$ FPR = \frac{FP}{TN + FP} $$

To compare classifiers, we can measure the *Area Under the Curve* (AUC). A perfect classifier will have a ROC AUC equal to 1, whereas a purely random classifier will have a ROC AUC equal to 0.5.
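
The two reference values can be checked on toy data. The sketch below does not use Scikit-Learn: the helpers `tpr_fpr()` and `roc_auc()` are hypothetical, and `roc_auc()` relies on the fact that the ROC AUC equals the fraction of (positive, negative) pairs in which the positive instance gets the higher score, with ties counted as half. A scorer that gives every instance the same score stands in for the uninformative case.

```python
def tpr_fpr(labels, scores, threshold):
    """True positive rate and false positive rate for score > threshold."""
    preds = [s > threshold for s in scores]
    tp = sum(p and y for p, y in zip(preds, labels))
    fn = sum((not p) and y for p, y in zip(preds, labels))
    fp = sum(p and (not y) for p, y in zip(preds, labels))
    tn = sum((not p) and (not y) for p, y in zip(preds, labels))
    return tp / (tp + fn), fp / (fp + tn)

def roc_auc(labels, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for y, s in zip(labels, scores) if y]
    neg = [s for y, s in zip(labels, scores) if not y]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0
               for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

labels = [False, False, False, True, True]
perfect = [0.1, 0.2, 0.3, 0.8, 0.9]    # every positive outranks every negative
constant = [0.5, 0.5, 0.5, 0.5, 0.5]   # uninformative: all scores tied
mixed = [0.1, 0.7, 0.3, 0.6, 0.9]      # one negative outranks one positive

print(roc_auc(labels, perfect))        # 1.0
print(roc_auc(labels, constant))       # 0.5
print(tpr_fpr(labels, mixed, 0.5))     # one point on the ROC curve of `mixed`
```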

### C. Multiclass Classification
Multiclass classifiers (also called *multinomial classifiers*) can distinguish between more than two classes. 

* **Native Support:** Some algorithms (such as Random Forest classifiers or naive Bayes classifiers) are capable of handling multiple classes directly.
* **Binary Strategies:** Others (such as Support Vector Machine classifiers or Linear classifiers) are strictly binary classifiers. However, there are various strategies that you can use to perform multiclass classification using multiple binary classifiers.
    * **One-versus-the-Rest (OvR):** For example, one way to create a system that can classify the digit images into 10 classes (from 0 to 9) is to train 10 binary classifiers, one for each digit (a 0-detector, a 1-detector, a 2-detector, and so on). Then when you want to classify an image, you get the decision score from each classifier for that image and you select the class whose classifier outputs the highest score.
    * **One-versus-One (OvO):** Another strategy is to train a binary classifier for every pair of digits: one to distinguish 0s and 1s, another to distinguish 0s and 2s, another for 1s and 2s, and so on. If there are $N$ classes, you need to train $N \times (N - 1) / 2$ classifiers. For the MNIST problem, this means training 45 binary classifiers!

Scikit-Learn detects when you try to use a binary classification algorithm for a multiclass classification task, and it automatically runs OvR or OvO, depending on the algorithm.
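
The classifier counts for the two strategies can be verified by simply enumerating them. This is a small illustrative sketch using only the standard library; the names `ovr_detectors` and `ovo_pairs` are made up here.

```python
from itertools import combinations

classes = range(10)  # the ten MNIST digits, 0 through 9

# OvR: one detector per class. OvO: one classifier per unordered pair of classes.
ovr_detectors = [f"{d}-vs-rest" for d in classes]
ovo_pairs = list(combinations(classes, 2))

n = len(classes)
print(len(ovr_detectors))                 # 10
print(len(ovo_pairs), n * (n - 1) // 2)   # 45 45
print(ovo_pairs[:3])                      # [(0, 1), (0, 2), (0, 3)]
```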

### D. Error Analysis
If you have found a promising model and you want to find ways to improve it, one way to do this is to analyze the types of errors it makes. First, you can look at the confusion matrix. You need to make predictions using the `cross_val_predict()` function, then call the `confusion_matrix()` function. It is often more useful to look at an image representation of the confusion matrix, using Matplotlib's `matshow()` function.

### E. Multilabel Classification
Until now each instance has always been assigned to just one class. In some cases you may want your classifier to output multiple classes for each instance. For example, consider a face-recognition classifier: what should it do if it recognizes several people on the same picture? Of course it should attach one tag per person it recognizes. Say the classifier has been trained to recognize three faces, Alice, Bob, and Charlie; then when it is shown a picture of Alice and Charlie, it should output [1, 0, 1] (meaning "Alice yes, Bob no, Charlie yes"). Such a classification system that outputs multiple binary tags is called a *multilabel classification* system.

### F. Multioutput Classification
The last type of classification task we are going to discuss here is called *multioutput-multiclass classification* (or simply *multioutput classification*). It is simply a generalization of multilabel classification where each label can be multiclass (i.e., it can have more than two possible values).

To illustrate this, let's build a system that removes noise from images. It will take as input a noisy digit image, and it will (hopefully) output a clean digit image, represented as an array of pixel intensities, just like the MNIST images. Notice that the classifier's output is multilabel (one label per pixel) and each label can have multiple values (pixel intensity ranges from 0 to 255). It is thus an example of a multioutput classification system.

## 3. Step-by-Step Implementation

### A. Getting the MNIST Data
Scikit-Learn provides many helper functions to download popular datasets. MNIST is one of them. The following code fetches the MNIST dataset.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml

# Fetch MNIST dataset
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
X, y = mnist["data"], mnist["target"]

# Convert labels from string to integer
y = y.astype(np.uint8)

print("Data shape:", X.shape)
print("Labels shape:", y.shape)

There are 70,000 images, and each image has 784 features. This is because each image is 28×28 pixels, and each feature simply represents one pixel’s intensity, from 0 (white) to 255 (black). Let’s grab one digit from the dataset, reshape it to a 28×28 array, and display it using Matplotlib’s `imshow()` function:

In [None]:
some_digit = X[0]
some_digit_image = some_digit.reshape(28, 28)

plt.imshow(some_digit_image, cmap="binary")
plt.axis("off")
plt.show()

print("Label:", y[0])

But wait! You should always create a test set and set it aside before inspecting the data closely. The MNIST dataset is actually already split into a training set (the first 60,000 images) and a test set (the last 10,000 images):

In [None]:
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

The training set is already shuffled for us, which is good as this guarantees that all cross-validation folds will be similar (you don’t want one fold to be missing some digits). Moreover, some learning algorithms are sensitive to the order of the training instances, and they perform poorly if they get many similar instances in a row. Shuffling the dataset ensures that this won’t happen.

### B. Training a Binary Classifier
Let’s simplify the problem for now and only try to identify one digit—for example, the number 5. This "5-detector" will be an example of a binary classifier, capable of distinguishing between just two classes, 5 and not-5. Let’s create the target vectors for this classification task:

In [None]:
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)

Okay, now let’s pick a classifier and train it. A good place to start is with a **Stochastic Gradient Descent (SGD)** classifier, using Scikit-Learn’s `SGDClassifier` class. This classifier has the advantage of being capable of handling very large datasets efficiently. This is in part because SGD deals with training instances independently, one at a time (which also makes SGD well suited for online learning), as we will see later. Let’s create an `SGDClassifier` and train it on the whole training set:

In [None]:
from sklearn.linear_model import SGDClassifier

sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)

# Now you can use it to detect images of the number 5:
print("Prediction for 'some_digit' (should be True):", sgd_clf.predict([some_digit]))

### C. Performance Measures

**Measuring Accuracy Using Cross-Validation**

We will use the `cross_val_score()` function to evaluate our `SGDClassifier` model using K-fold cross-validation, with three folds. Remember that K-fold cross-validation means splitting the training set into K folds (in this case, three), then making predictions and evaluating them on each fold using a model trained on the remaining folds.
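
To make the fold mechanics concrete, the loop below is a rough hand-rolled sketch of K-fold cross-validation for a classifier. It is not the code used in this chapter: it assumes Scikit-Learn's small bundled 8×8 digits dataset (`load_digits`) as a stand-in so that it runs quickly without downloading MNIST, and the names `X_small`, `y_small_5`, and `fold_accuracies` are made up for the example.

```python
from sklearn.base import clone
from sklearn.datasets import load_digits
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import StratifiedKFold

# Stand-in data (an assumption): the small digits set bundled with Scikit-Learn.
X_small, y_small = load_digits(return_X_y=True)
y_small_5 = (y_small == 5)

clf = SGDClassifier(random_state=42)
skfolds = StratifiedKFold(n_splits=3)  # folds keep the class ratio of the full set

fold_accuracies = []
for train_index, test_index in skfolds.split(X_small, y_small_5):
    clone_clf = clone(clf)  # fresh, untrained copy for each fold
    clone_clf.fit(X_small[train_index], y_small_5[train_index])
    y_pred = clone_clf.predict(X_small[test_index])
    # Accuracy on the held-out fold: ratio of correct predictions.
    fold_accuracies.append((y_pred == y_small_5[test_index]).mean())

print(fold_accuracies)
```

Writing the loop yourself is mostly useful when you need more control over the process than `cross_val_score()` offers.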

In [None]:
from sklearn.model_selection import cross_val_score

accuracy_scores = cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring="accuracy")
print("Accuracy Scores:", accuracy_scores)

Wow! Above 95% accuracy (ratio of correct predictions) on all cross-validation folds? This looks amazing, doesn’t it? Well, before you get too excited, let’s look at a very dumb classifier that just classifies every single image in the "not-5" class:

In [None]:
from sklearn.base import BaseEstimator

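class Never5Classifier(BaseEstimator):
    """Baseline that ignores its input and always predicts "not-5"."""
    def fit(self, X, y=None):
        return self  # nothing to learn; returning self follows the Scikit-Learn convention
    def predict(self, X):
        # One False per instance, as a 1D array with the same shape as the targets
        return np.zeros(len(X), dtype=bool)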

never_5_clf = Never5Classifier()
never_5_scores = cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring="accuracy")
print("Never5 Classifier Scores:", never_5_scores)

That’s right, it has over 90% accuracy! This is simply because only about 10% of the images are 5s, so if you always guess that an image is not a 5, you will be right about 90% of the time. This demonstrates why accuracy is generally not the preferred performance measure for classifiers, especially when you are dealing with skewed datasets.

**Confusion Matrix**

A much better way to evaluate the performance of a classifier is to look at the confusion matrix. To compute the confusion matrix, you first need to have a set of predictions, so they can be compared to the actual targets. You could make predictions on the test set, but let’s keep it untouched for now. Instead, you can use the `cross_val_predict()` function:

In [None]:
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix

y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)
cm = confusion_matrix(y_train_5, y_train_pred)
print("Confusion Matrix:\n", cm)

Each row in a confusion matrix represents an actual class, while each column represents a predicted class. The first row of this matrix considers non-5 images (the negative class): 53,057 of them were correctly classified as non-5s (**true negatives**), while the remaining were wrongly classified as 5s (**false positives**). The second row considers the images of 5s (the positive class): some were wrongly classified as non-5s (**false negatives**), while the remaining were correctly classified as 5s (**true positives**). A perfect classifier would have only true positives and true negatives, so its confusion matrix would have nonzero values only on its main diagonal (top left to bottom right).

**Precision and Recall**

Scikit-Learn provides several convenience functions to compute classifier metrics, including precision and recall.

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score

print("Precision:", precision_score(y_train_5, y_train_pred))
print("Recall:", recall_score(y_train_5, y_train_pred))
print("F1 Score:", f1_score(y_train_5, y_train_pred))

Now your 5-detector does not look as shiny as it did when you looked at its accuracy. When it claims an image represents a 5, it is correct only part of the time. Moreover, it only detects a fraction of the 5s.

**Precision/Recall Trade-off**

To understand this trade-off, let's look at how the `SGDClassifier` makes its decisions. For each instance, it computes a score based on a decision function. If that score is greater than a threshold, it assigns the instance to the positive class; otherwise it assigns it to the negative class.

Scikit-Learn does not let you set the threshold directly, but it does give you access to the decision scores that it uses to make predictions. Instead of calling the classifier’s `predict()` method, you can call its `decision_function()` method, which returns a score for each instance, and then make predictions based on those scores using any threshold you want:

In [None]:
y_scores = sgd_clf.decision_function([some_digit])
print("Score:", y_scores)
threshold = 0
y_some_digit_pred = (y_scores > threshold)
print("Prediction with threshold 0:", y_some_digit_pred)

The `SGDClassifier` uses a threshold of 0, so the previous code returns the same result as the `predict()` method (i.e., True). Let’s raise the threshold:

In [None]:
threshold = 8000
y_some_digit_pred = (y_scores > threshold)
print("Prediction with threshold 8000:", y_some_digit_pred)

This confirms that raising the threshold decreases recall. The image actually represents a 5, and the classifier detects it when the threshold is 0, but it misses it when the threshold is increased to 8,000.

How can you decide which threshold to use? For this you will first need to get the scores of all instances in the training set using `cross_val_predict()` again, but this time specifying that you want it to return decision scores instead of predictions:

In [None]:
from sklearn.metrics import precision_recall_curve

y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, method="decision_function")
precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)

def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    plt.plot(thresholds, precisions[:-1], "b--", label="Precision")
    plt.plot(thresholds, recalls[:-1], "g-", label="Recall")
    plt.legend()
    plt.xlabel("Threshold")
    plt.grid(True)

plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.show()
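
Once you have the precision and recall arrays, one way to pick a threshold is to search for the lowest one that reaches a target precision. The sketch below illustrates the idea on synthetic scores rather than on the MNIST scores computed above: the data, the 90% target, and every name ending in `_demo` are assumptions made for the example.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve, precision_score, recall_score

# Synthetic stand-in for cross-validated decision scores (an assumption):
# 10% positives, which score higher on average than the negatives.
rng = np.random.default_rng(42)
y_true_demo = np.concatenate([np.zeros(900, dtype=bool), np.ones(100, dtype=bool)])
y_scores_demo = np.concatenate([rng.normal(-2.0, 1.5, 900),
                                rng.normal(2.0, 1.5, 100)])

prec_demo, rec_demo, thr_demo = precision_recall_curve(y_true_demo, y_scores_demo)

# np.argmax returns the first index where the condition is True, i.e. the
# lowest threshold reaching the target. This assumes the target is reachable;
# if it is not, argmax silently returns 0.
target_precision = 0.90
idx = np.argmax(prec_demo[:-1] >= target_precision)
threshold_90 = thr_demo[idx]

# Predict positive for scores at or above the chosen threshold.
y_pred_90 = (y_scores_demo >= threshold_90)
print("threshold:", threshold_90)
print("precision:", precision_score(y_true_demo, y_pred_90))
print("recall:", recall_score(y_true_demo, y_pred_90))
```

A high-precision classifier is not very useful if its recall is too low, so it is worth printing both numbers before settling on a threshold.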

**The ROC Curve**

The receiver operating characteristic (ROC) curve is another common tool used with binary classifiers. It is very similar to the precision/recall curve, but instead of plotting precision versus recall, the ROC curve plots the true positive rate (another name for recall) against the false positive rate. The FPR is the ratio of negative instances that are incorrectly classified as positive.

In [None]:
from sklearn.metrics import roc_curve, roc_auc_score

fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)

def plot_roc_curve(fpr, tpr, label=None):
    plt.plot(fpr, tpr, linewidth=2, label=label)
    plt.plot([0, 1], [0, 1], 'k--') # Dashed diagonal
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate (Recall)')
    plt.grid(True)

plot_roc_curve(fpr, tpr)
plt.show()

print("ROC AUC Score:", roc_auc_score(y_train_5, y_scores))

### D. Multiclass Classification

Now let’s try to detect more than just the 5s. Some algorithms (such as Random Forest classifiers or naive Bayes classifiers) are capable of handling multiple classes directly. Others (such as Support Vector Machine classifiers or Linear classifiers) are strictly binary classifiers. However, there are various strategies that you can use to perform multiclass classification using multiple binary classifiers (OvR, OvO).

Scikit-Learn detects when you try to use a binary classification algorithm for a multiclass classification task, and it automatically runs OvR or OvO, depending on the algorithm. Let’s try this with the `SGDClassifier`:

In [None]:
sgd_clf.fit(X_train, y_train) # y_train, not y_train_5
print("Prediction for some_digit (multiclass):", sgd_clf.predict([some_digit]))

This code trains the `SGDClassifier` on the training set using the original target classes from 0 to 9. Under the hood, Scikit-Learn actually trained 10 binary classifiers, got their decision scores for the image, and selected the class with the highest score.

To validate this, we can call the `decision_function()` method. Instead of returning just one score per instance, it now returns 10 scores, one per class:

In [None]:
some_digit_scores = sgd_clf.decision_function([some_digit])
print("Scores for some_digit:", some_digit_scores)
print("Class with max score:", np.argmax(some_digit_scores))
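
The "highest score wins" rule can be reproduced by hand. The sketch below is an illustration on Scikit-Learn's small bundled digits dataset (`load_digits`, an assumption made so that it runs quickly without MNIST); the names `sgd_demo`, `scores_demo`, and `by_hand` are hypothetical.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.linear_model import SGDClassifier

# Stand-in data (an assumption): 8x8 digit images, 64 features each.
X_small, y_small = load_digits(return_X_y=True)

sgd_demo = SGDClassifier(random_state=42)
sgd_demo.fit(X_small, y_small)

# One row per instance, one decision score per class.
scores_demo = sgd_demo.decision_function(X_small[:5])

# Pick the class whose binary classifier produced the highest score.
by_hand = sgd_demo.classes_[np.argmax(scores_demo, axis=1)]

print(sgd_demo.coef_.shape)  # (10, 64): one weight vector per class
print(by_hand)
print(sgd_demo.predict(X_small[:5]))
```

Indexing `classes_` with the argmax matters in general: the position of the highest score is only equal to the class label when the classes happen to be 0, 1, 2, and so on.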

If you want to force Scikit-Learn to use one-versus-one or one-versus-the-rest, you can use the `OneVsOneClassifier` or `OneVsRestClassifier` classes. Simply create an instance and pass a binary classifier to its constructor. For example, this code creates a multiclass classifier using the OvO strategy, based on an `SGDClassifier`:

In [None]:
from sklearn.multiclass import OneVsOneClassifier
ovo_clf = OneVsOneClassifier(SGDClassifier(random_state=42))
ovo_clf.fit(X_train, y_train)
print("OvO Prediction:", ovo_clf.predict([some_digit]))
print("Number of estimators:", len(ovo_clf.estimators_))

Training a `RandomForestClassifier` is just as easy:

In [None]:
from sklearn.ensemble import RandomForestClassifier

forest_clf = RandomForestClassifier(random_state=42)
forest_clf.fit(X_train, y_train)
print("Random Forest Prediction:", forest_clf.predict([some_digit]))

This time Scikit-Learn did not have to run OvR or OvO because Random Forest classifiers can directly classify instances into multiple classes. You can call `predict_proba()` to get the probabilities that the classifier assigned to each class for a given instance.
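
A minimal sketch of such a call is shown here. It assumes Scikit-Learn's small bundled digits dataset (`load_digits`) as a stand-in for MNIST, a reduced `n_estimators=50` to keep it fast, and the made-up names `forest_demo` and `proba`.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier

# Stand-in data (an assumption): the small 8x8 digits set bundled with Scikit-Learn.
X_small, y_small = load_digits(return_X_y=True)

forest_demo = RandomForestClassifier(n_estimators=50, random_state=42)
forest_demo.fit(X_small[:1500], y_small[:1500])

# One row per instance, one column per class (ordered as in classes_).
proba = forest_demo.predict_proba(X_small[1500:1501])
print(proba.round(2))
print("Most likely class:", forest_demo.classes_[np.argmax(proba[0])])
```

Each row sums to 1, and the class with the highest estimated probability is the one that `predict()` returns.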

Now let’s evaluate these classifiers using cross-validation. Let’s evaluate the `SGDClassifier`’s accuracy using the `cross_val_score()` function:

In [None]:
scores = cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring="accuracy")
print("SGD Multiclass Scores:", scores)

It gets over 84% on all test folds. If you used a random classifier, you would get 10% accuracy, so this is not such a bad score, but you can still do much better. For example, simply scaling the inputs (as discussed in Chapter 2) increases accuracy above 89%:

In [None]:
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train.astype(np.float64))
scaled_scores = cross_val_score(sgd_clf, X_train_scaled, y_train, cv=3, scoring="accuracy")
print("Scaled SGD Multiclass Scores:", scaled_scores)

### E. Error Analysis

Of course, if this were a real project, you would follow the steps in your Machine Learning project checklist: exploring data preparation options, trying out multiple models, shortlisting the best ones, and fine-tuning their hyperparameters using `GridSearchCV`. However, let’s assume that you have found a promising model and you want to find ways to improve it. One way to do this is to analyze the types of errors it makes.

First, you can look at the confusion matrix. You need to make predictions using the `cross_val_predict()` function, then call the `confusion_matrix()` function:

In [None]:
y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)
conf_mx = confusion_matrix(y_train, y_train_pred)
print("Multiclass Confusion Matrix:\n", conf_mx)

That’s a lot of numbers. It’s often more convenient to look at an image representation of the confusion matrix, using Matplotlib’s `matshow()` function:

In [None]:
plt.matshow(conf_mx, cmap=plt.cm.gray)
plt.show()

This confusion matrix looks fairly good, since most images are on the main diagonal, which means that they were classified correctly. The 5s look slightly darker than the other digits, which could mean that there are fewer images of 5s in the dataset or that the classifier does not perform as well on 5s as on other digits. In fact, you can verify that both are the case.

Let’s focus on the plot of the errors. First, you need to divide each value in the confusion matrix by the number of images in the corresponding class, so you can compare error rates instead of absolute number of errors (which would make abundant classes look unfairly bad). Then fill the diagonal with zeros to keep only the errors, and plot the result:

In [None]:
row_sums = conf_mx.sum(axis=1, keepdims=True)
norm_conf_mx = conf_mx / row_sums
np.fill_diagonal(norm_conf_mx, 0)
plt.matshow(norm_conf_mx, cmap=plt.cm.gray)
plt.show()

Now you can clearly see the kinds of errors the classifier makes. Remember that rows represent actual classes, while columns represent predicted classes. The column for class 8 is quite bright, which tells you that many images get misclassified as 8s. However, the row for class 8 is not that bad, telling you that actual 8s in general get properly classified as 8s. As you can see, the confusion matrix is not necessarily symmetrical. You can also see that 3s and 5s often get confused (in both directions).

### F. Multilabel Classification

Until now each instance has always been assigned to just one class. In some cases you may want your classifier to output multiple classes for each instance. For example, consider a face-recognition classifier: what should it do if it recognizes several people on the same picture? Of course it should attach one tag per person it recognizes. Say the classifier has been trained to recognize three faces, Alice, Bob, and Charlie; then when it is shown a picture of Alice and Charlie, it should output [1, 0, 1] (meaning "Alice yes, Bob no, Charlie yes"). Such a classification system that outputs multiple binary tags is called a *multilabel classification* system.

We won’t use facial recognition just yet, but let’s look at a simpler example, using the MNIST dataset. The code below creates a `y_multilabel` array containing two target labels for each digit image: the first indicates whether or not the digit is large (7, 8, or 9), and the second indicates whether or not it is odd. The next lines create a `KNeighborsClassifier` instance (which supports multilabel classification, but not all classifiers do), and we train it using the multiple targets array. Now you can make a prediction, and notice that it outputs two labels:

In [None]:
from sklearn.neighbors import KNeighborsClassifier

y_train_large = (y_train >= 7)
y_train_odd = (y_train % 2 == 1)
y_multilabel = np.c_[y_train_large, y_train_odd]

knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train, y_multilabel)

print("Multilabel prediction for '5' (Large?, Odd?):", knn_clf.predict([some_digit]))

And it gets it right! The digit 5 is indeed not large (False) and odd (True). There are many ways to evaluate a multilabel classifier, and selecting the right metric depends on your project. For example, one approach is to measure the $F_1$ score for each individual label (or any other binary classifier metric discussed earlier), then simply compute the average score. This assumes that all labels are equally important, which may not be the case. In particular, if you have many more pictures of Alice than of Bob or Charlie, you may want to give more weight to the classifier’s score on pictures of Alice. One simple option is to give each label a weight equal to its support (i.e., the number of instances with that target label). To do this, set `average="weighted"` when calling `f1_score()`; with `average="macro"`, every label counts equally.
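
The two averaging options can be compared side by side. The sketch below is an illustration only: it assumes Scikit-Learn's small bundled digits dataset (`load_digits`) instead of MNIST so that cross-validating a `KNeighborsClassifier` stays fast, and the names ending in `_demo` are made up.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.metrics import f1_score
from sklearn.model_selection import cross_val_predict
from sklearn.neighbors import KNeighborsClassifier

# Stand-in data (an assumption): the small 8x8 digits set bundled with Scikit-Learn.
X_small, y_small = load_digits(return_X_y=True)

# Two binary labels per image: "large" (7, 8, or 9) and "odd".
y_multilabel_demo = np.c_[y_small >= 7, y_small % 2 == 1]

knn_demo = KNeighborsClassifier()
y_knn_pred = cross_val_predict(knn_demo, X_small, y_multilabel_demo, cv=3)

# "macro": plain average of the per-label F1 scores.
# "weighted": average weighted by each label's support.
f1_macro = f1_score(y_multilabel_demo, y_knn_pred, average="macro")
f1_weighted = f1_score(y_multilabel_demo, y_knn_pred, average="weighted")
print("Macro F1:", f1_macro)
print("Weighted F1:", f1_weighted)
```

Here the "odd" label has more positive instances than the "large" label, so the weighted average leans toward the classifier's score on "odd".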

### G. Multioutput Classification

The last type of classification task we are going to discuss here is called *multioutput-multiclass classification* (or simply *multioutput classification*). It is simply a generalization of multilabel classification where each label can be multiclass (i.e., it can have more than two possible values).

To illustrate this, let’s build a system that removes noise from images. It will take as input a noisy digit image, and it will (hopefully) output a clean digit image, represented as an array of pixel intensities, just like the MNIST images. Notice that the classifier’s output is multilabel (one label per pixel) and each label can have multiple values (pixel intensity ranges from 0 to 255). It is thus an example of a multioutput classification system.

Let’s start by creating the training and test sets by taking the MNIST images and adding noise to their pixel intensities using NumPy’s `randint()` function. The target images will be the original images:

In [None]:
# Add noise to data
noise = np.random.randint(0, 100, (len(X_train), 784))
X_train_mod = X_train + noise
noise = np.random.randint(0, 100, (len(X_test), 784))
X_test_mod = X_test + noise
y_train_mod = X_train
y_test_mod = X_test

# Visualize the noisy vs clean images
def plot_digit(data):
    image = data.reshape(28, 28)
    plt.imshow(image, cmap="binary", interpolation="nearest")
    plt.axis("off")

some_index = 0
plt.subplot(121); plot_digit(X_test_mod[some_index])
plt.subplot(122); plot_digit(y_test_mod[some_index])
plt.show()

Now let’s train the classifier and make it clean this image:

In [None]:
knn_clf.fit(X_train_mod, y_train_mod)
clean_digit = knn_clf.predict([X_test_mod[some_index]])
plot_digit(clean_digit)
plt.show()

Looks close enough to the target! This concludes our tour of classification. You should now know how to select good metrics for classification tasks, pick the appropriate precision/recall trade-off, compare classifiers, and more generally build good classification systems for a variety of tasks.