# Classifier Diagnostics

Task: plot a confusion matrix and find images that were misclassified.

## Setup

You do not need to read or modify the code in this section to successfully complete this assignment.

In [None]:
# Import fastai code.
from fastai.vision.all import *

# Set a seed for reproducibility.
set_seed(0, reproducible=True)

Monkey-patch `ClassificationInterpretation.plot_top_losses` to work around a bug.

In [None]:
def _plot_top_losses(self, k, largest=True, **kwargs):
    losses,idx = self.top_losses(k, largest)
    if not isinstance(self.inputs, tuple): self.inputs = (self.inputs,)
    if isinstance(self.inputs[0], Tensor): inps = tuple(o[idx] for o in self.inputs)
    else: inps = self.dl.create_batch(self.dl.before_batch([tuple(o[i] for o in self.inputs) for i in idx]))
    b = inps + tuple(o[idx] for o in (self.targs if is_listy(self.targs) else (self.targs,)))
    x,y,its = self.dl._pre_show_batch(b, max_n=k)
    b_out = inps + tuple(o[idx] for o in (self.decoded if is_listy(self.decoded) else (self.decoded,)))
    x1,y1,outs = self.dl._pre_show_batch(b_out, max_n=k)
    if its is not None:
        plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), self.preds[idx], losses,  **kwargs)
ClassificationInterpretation.plot_top_losses = _plot_top_losses

### Set up the dataset

In [None]:
path = untar_data(URLs.PETS)/'images'

In [None]:
image_files = get_image_files(path).sorted()

In [None]:
# Cat images have filenames that start with a capital letter.
def is_cat(filename):
    return filename[0].isupper()

### Deliberately corrupt some of the image labels

In [None]:
FLIP_PROB = 0.25
# Use a loop variable that does not shadow the dataset `path` defined above.
correct_labels = [is_cat(image_file.name) for image_file in image_files]
corrupted_labels = [
    not correct_label if random.random() < FLIP_PROB else correct_label
    for correct_label in correct_labels]

Check how many labels are still correct.
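
As a rough guide to what to expect: if each label is flipped independently with probability `FLIP_PROB`, about `1 - FLIP_PROB` of the labels should survive unchanged. The sketch below illustrates this on made-up labels. The names `demo_flip_prob`, `demo_size`, and `rng` are hypothetical and are not part of the notebook's setup, and the labels are random rather than taken from the pets dataset.

```python
import random

demo_flip_prob = 0.25   # assumed to match FLIP_PROB in this notebook
demo_size = 10_000      # hypothetical number of labels, for illustration only

# Use a local generator so the notebook's global random state is left alone.
rng = random.Random(0)

# Made-up ground-truth labels: True = cat, False = dog.
true_labels = [rng.random() < 0.5 for _ in range(demo_size)]

# Flip each label independently with probability demo_flip_prob.
noisy_labels = [
    (not label) if rng.random() < demo_flip_prob else label
    for label in true_labels
]

# Fraction of labels that were left unchanged.
fraction_unchanged = sum(
    t == n for t, n in zip(true_labels, noisy_labels)
) / demo_size
print(f"Fraction of labels still correct: {fraction_unchanged:.3f}")
```

The printed fraction should land close to `1 - demo_flip_prob`, with a little random variation from run to run if you change the seed.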

In [None]:
sum(
    correct_label == corrupted_label
    for correct_label, corrupted_label in zip(correct_labels, corrupted_labels)
) / len(correct_labels)

### Train the classifier on the (corrupted) labels

In [None]:
dataloaders = ImageDataLoaders.from_lists(
    path=path, fnames=image_files, labels=corrupted_labels,
    valid_pct=0.2,
    seed=42,
    item_tfms=Resize(224)
)

In [None]:
learn = cnn_learner(
    dls=dataloaders,
    arch=resnet18,
    metrics=accuracy
)
learn.fine_tune(epochs=4)
learn.recorder.plot_loss()

## Task

We've given you a classifier (the `learn` object). It turns out that it was trained on a *corrupted* dataset where some of the labels were flipped, but let's pretend that we didn't know that. Could we figure out where the problems are by looking at the results of the classification?

Follow these steps:

1. Show one batch from each of the (corrupted) training and validation sets. (`dataloaders.train.show_batch()`)

*Note: You may or may not actually see a mislabeled image here.*

In [None]:
# your code here
dataloaders.train.show_batch()

In [None]:
# your code here
dataloaders.valid.show_batch()

2. Compute the *accuracy* and *error rate* of this classifier on the (corrupted) validation set (`accuracy(interp.preds, interp.targs)`). Check that the accuracy matches the last accuracy figure reported while training above. Then multiply the error rate by the number of images in the validation set to get the number of misclassified images.

*Hint*: you may need `WHATEVER.item()` to get a plain number instead of a `Tensor`.
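
The arithmetic behind this step can be sketched with plain numbers. The values below are hypothetical placeholders, not results from this notebook; in the real cell they would come from `accuracy(...).item()` and the validation set size.

```python
# Hypothetical values, for illustration only.
demo_accuracy = 0.74   # assumed validation accuracy
demo_n_valid = 1000    # assumed number of validation images

# Error rate is the complement of accuracy.
demo_error_rate = 1 - demo_accuracy

# Scale the error rate by the dataset size and round to a whole image count.
demo_n_misclassified = round(demo_error_rate * demo_n_valid)

print(f"Error rate: {demo_error_rate:.2f}")
print(f"Misclassified: {demo_n_misclassified} out of {demo_n_valid}")
```

Rounding is needed because the error rate is a floating-point number, so the product is rarely an exact integer.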

In [None]:
# your code here
interp = ClassificationInterpretation.from_learner(learn)
# Compute each metric once and convert it from a Tensor to a plain number.
valid_accuracy = accuracy(interp.preds, interp.targs).item()
valid_error_rate = error_rate(interp.preds, interp.targs).item()
n_valid = dataloaders.valid.n
print("Accuracy:", valid_accuracy)
print("Error rate:", valid_error_rate)
print(f"Number of images incorrect: {round(valid_error_rate * n_valid)} out of {n_valid}")

3. Plot the confusion matrix on the (corrupted) validation set (see chapter 2).
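
To see what the plot summarizes, here is a minimal sketch that builds a 2x2 confusion matrix by hand. The `demo_targets` and `demo_predictions` lists are made up for illustration and are not the classifier's output. The layout assumed here (rows = actual label, columns = predicted label) is a convention chosen for this sketch; check the axis labels on the plot you get.

```python
from collections import Counter

# Made-up labels and predictions (True = cat, False = dog).
demo_targets     = [True, True, False, False, True, False, False, True]
demo_predictions = [True, False, False, True, True, False, False, False]

# Count how often each (actual, predicted) pair occurs.
pair_counts = Counter(zip(demo_targets, demo_predictions))

# Rows are actual labels, columns are predicted labels, both in this order.
classes = [False, True]
confusion = [
    [pair_counts[(actual, predicted)] for predicted in classes]
    for actual in classes
]

for actual, row in zip(classes, confusion):
    print(f"actual={actual!s:5} -> {row}")

# Correct predictions sit on the diagonal.
n_correct = sum(confusion[i][i] for i in range(len(classes)))
print("Accuracy from the matrix:", n_correct / len(demo_targets))
```

The off-diagonal cells are the two kinds of mistakes: one cell counts dogs predicted as cats, the other counts cats predicted as dogs.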

In [None]:
# your code here
interp.plot_confusion_matrix()

4. Compute the accuracy on the (corrupted) *training* set. (Since "dataset 0" is the training set and "dataset 1" is the validation set, we can use `interp_train = ClassificationInterpretation.from_learner(learn, ds_idx=0)`)

In [None]:
interp_train = ClassificationInterpretation.from_learner(learn, ds_idx=0)
# your code here
print("Accuracy:", accuracy(interp_train.preds, interp_train.targs).item())

5. Plot the top 12 losses in the validation set.
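
A rough sketch of how "top losses" are ranked, assuming the learner uses cross-entropy loss (believed to be the default for this kind of classifier, but treat that as an assumption). Under that assumption, the loss for one image is the negative log of the probability the model assigned to the image's *label*. The image names and probabilities below are hypothetical.

```python
import math

# Hypothetical (name, label_is_cat, predicted probability of "cat") triples.
demo_examples = [
    ("img_a", True, 0.98),
    ("img_b", False, 0.97),  # model says "cat" with confidence, label says "dog"
    ("img_c", False, 0.10),
    ("img_d", True, 0.55),
]

def cross_entropy(label_is_cat, p_cat):
    """Negative log of the probability assigned to the labelled class."""
    p_label = p_cat if label_is_cat else 1 - p_cat
    return -math.log(p_label)

# Sort from largest loss to smallest, as a top-losses view would.
ranked = sorted(
    demo_examples,
    key=lambda example: cross_entropy(example[1], example[2]),
    reverse=True,
)

for name, label_is_cat, p_cat in ranked:
    print(f"{name}: label_is_cat={label_is_cat}, p_cat={p_cat:.2f}, "
          f"loss={cross_entropy(label_is_cat, p_cat):.3f}")
```

Images where the model is confident but disagrees with the label rise to the top, which is why this view is useful for spotting suspicious labels.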

In [None]:
interp.plot_top_losses(12)

## Analysis

1. In the (corrupted) validation set, **how many dogs were misclassified as cats? Vice versa?**

In [None]:
num_incorrectly_labeled_cat = ...
num_incorrectly_labeled_dog = ...

2. **If we had only looked at the accuracy on the (corrupted) training set, would we have *overestimated* or *underestimated* the validation set performance? By how much?**

*your answer here*

3. Find some images in the validation set that were "misclassified" by looking at the top losses. **What does the classifier do with images that were mislabeled?** Explain what "probability" means in this output.

*your answer here*