# Classifying Snakes

This notebook demonstrates how `deepchecks` can help you evaluate a trained computer vision classification model.

This scenario is one that computer vision researchers run into in real life. Often, when given a dataset in a production environment, an initial model is trained and then evaluated against common benchmarks, only to achieve relative "success". For classification, such benchmarks are usually *accuracy*, *precision* and *recall*.
The main problem is that these metrics often hide serious problems with our data. In academic scenarios, the test set is sampled from the same general set as the training data, and that's completely fine. When working on real products, a "test" set should usually cover a much wider data distribution, one that isn't available to us until the product has actually been shipped to users. In such scenarios we can quickly discover that our model isn't robust enough for various real-world cases which simply don't exist in our data. Examples include a car dataset which doesn't contain enough color variations, or a pedestrian detection model trained only on images taken at a specific time of day, when there's less (or more) sunlight.

DeepChecks' computer vision toolkit includes a RobustnessCheck tool, which allows us to take a trained model and try to find such corner cases. To do that, it uses commonly (and less commonly) used image augmentations as ready-made image transformations.

Image augmentations are a general group of image operators that modify an image used as a data point when training a model. Augmenting a dataset effectively makes it larger by adding slightly different versions of its images. Not every image augmentation is suited for every task. A common example is the horizontal flip, which is often used since most objects look the same when mirrored left to right. This isn't true for a vertical flip, though, and because of that it's less commonly used as a default, despite the fact that it's actually quite useful for a *lot* of datasets (just about any image taken from above).
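
To make the two mirroring operations discussed above concrete, here is a minimal NumPy sketch on a tiny made-up array. It is only an illustration of what each operation does to pixel positions, not the augmentation code used by DeepChecks.

```python
import numpy as np

# A tiny 2x3 single-channel "image" (made-up values) so each mirror is easy to read.
image = np.array([[1, 2, 3],
                  [4, 5, 6]])

# Mirror left-to-right: the column order is reversed, rows stay in place.
h_flipped = np.fliplr(image)

# Mirror top-to-bottom: the row order is reversed, columns stay in place.
v_flipped = np.flipud(image)

print(h_flipped)
print(v_flipped)
```

A snake photographed from the side usually still looks plausible after the left-to-right mirror, while the top-to-bottom mirror puts the ground above the animal, which is why the latter is applied more selectively.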

In the RobustnessCheck, as we'll soon see, we take a given trained model and run the test suite to visualize which image transformations (i.e. augmentations) the trained model is less "robust" to.
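
The idea behind such a check can be sketched without any library support: score the model on the original images, score it again on augmented copies, and report the drop per augmentation. The snippet below is a hypothetical toy version of that idea. `toy_model`, the synthetic data and the `augmentations` dictionary are all assumptions made for illustration, and none of it reflects how RobustnessCheck is actually implemented.

```python
import numpy as np

def toy_model(images):
    """Hypothetical classifier: predicts 1 when the top half is brighter than the bottom half."""
    half = images.shape[1] // 2
    top = images[:, :half].mean(axis=(1, 2))
    bottom = images[:, half:].mean(axis=(1, 2))
    return (top > bottom).astype(int)

def accuracy(model, images, labels):
    """Fraction of images for which the model's prediction matches the label."""
    return float((model(images) == labels).mean())

# Synthetic grayscale "dataset" of shape (N, H, W); the labels are defined by the
# model itself, so accuracy on the untouched images is 1.0 by construction.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(100, 8, 8))
labels = toy_model(images)

# Candidate augmentations, each mapping a batch of images to an augmented batch.
augmentations = {
    "horizontal_flip": lambda batch: batch[:, :, ::-1],
    "vertical_flip": lambda batch: batch[:, ::-1, :],
}

baseline = accuracy(toy_model, images, labels)
drops = {
    name: baseline - accuracy(toy_model, augment(images), labels)
    for name, augment in augmentations.items()
}
print(baseline, drops)
```

In this toy setup the horizontal flip leaves every prediction unchanged, while the vertical flip swaps the two halves the model compares and so inverts nearly all predictions. A large drop for one augmentation is exactly the kind of weakness a robustness check is meant to surface.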

(Note: The data was taken from Kaggle's [Pre-processed Snake Images](https://www.kaggle.com/sameeharahman/preprocessed-snake-images). Its license allows for commercial use. The model used for this example was trained by us and is also open for any sort of use. The train/val split is ours.)

**Installing requirements**

In [None]:
import sys

import matplotlib.pyplot as plt
!{sys.executable} -m pip install deepchecks --quiet

## Loading the data

OK, let's take a look at the data!

In [None]:
from deepchecks.vision.datasets.classification.snakes import load_data

When dealing with images, a dataset usually contains tuples of *image and label*. Let's fetch the first image and take a close look.

In [25]:
dataset = load_data()
datapoint = dataset[0]
len(datapoint)
datapoint[0].shape

(384, 384, 3)

So, images in the dataset are 384x384. Let's visualize the first two images.

In [None]:

import plotly.graph_objects as go
from plotly.subplots import make_subplots
from skimage import io

# Show the first two raw images side by side.
fig = make_subplots(rows=1, cols=2)
fig.add_trace(go.Image(z=dataset[0][0]), 1, 1)
fig.add_trace(go.Image(z=dataset[1][0]), 1, 2)

Note that since we haven't defined any *transforms* for the dataset yet, it outputs raw images. When working with DL models, the output of a dataset/dataloader is usually preprocessed to fit the model architecture. Let's do that:

In [23]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
transforms = A.Compose([
    A.SmallestMaxSize(max_size=256),
    A.CenterCrop(height=224, width=224),
    A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    ToTensorV2(),
])
dataset.transforms = transforms

Let's try again:

In [24]:
datapoint = dataset[0]
len(datapoint)
datapoint[0].shape

torch.Size([3, 224, 224])

Okay, so now we have the 224x224 normalized center crop used in many CV applications. Note that so far we are using the regular torch Dataset object without any added bells and whistles.
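
For readers who want to see the arithmetic behind that output shape, the following NumPy sketch mirrors the crop, normalization and channel reordering steps on a random stand-in image. It is an approximation written for illustration, not the albumentations implementation, and it skips the resize step by assuming the image has already been resized to 256x256.

```python
import numpy as np

# Stand-in for an image already resized so that its shorter side is 256 pixels
# (random uint8 data in height x width x channel layout; an assumption for this sketch).
rng = np.random.default_rng(0)
resized = rng.integers(0, 256, size=(256, 256, 3), dtype=np.uint8)

# Center crop: keep the middle 224x224 window.
crop = 224
top = (resized.shape[0] - crop) // 2
left = (resized.shape[1] - crop) // 2
cropped = resized[top:top + crop, left:left + crop]

# Normalize: scale to [0, 1], then standardize each channel with the same
# mean and std values used in the transform pipeline.
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
normalized = (cropped.astype(np.float32) / 255.0 - mean) / std

# Reorder from height x width x channel to channel x height x width,
# the layout PyTorch models expect.
chw = normalized.transpose(2, 0, 1)
print(chw.shape)  # (3, 224, 224)
```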

In [27]:
print(len(dataset))
dataset.classes

3477


['class-0', 'class-1', 'class-2', 'class-3', 'class-4']

So, our dataset has 3477 images and 5 classes.
The reason we have gathered here is to see how the DeepChecks library can help us evaluate our model. Let's load the DeepChecks performance check.

In [None]:
## Checking Model Performance with DeepChecks