# Simple Model Comparison

This notebook provides an overview of using and understanding the simple model comparison check.

**Structure:**

- [What is the purpose of the check?](#purpose)
- [Generate data & model](#generate_data_model)
- [Run the check](#run_check)

## What is the purpose of the check? <a name='purpose'></a>

The simple model is designed to produce the best performance achievable using very simple rules. Its goal is to provide a baseline of minimal model performance for the given task, against which the user model can be compared. If the user model achieves a lower or similar score to the simple model, this indicates a possible problem with the model (e.g. it wasn't trained properly).

In the computer vision module, this check applies only to classification problems.

The check has four possible strategies for selecting the behavior of the baseline simple model. By default the check uses the **prior** strategy, which can be overridden with the check's `strategy` parameter. Similar to the [tabular simple model comparison check](../../../tabular/checks/performance/simple_model_comparison), no simple model is more "correct" than another; each gives a different baseline to compare to, and you may experiment with the different strategies to see how your model compares on your data.

The available strategies are:
- **prior** (Default) - The probability vector always contains the empirical class prior distribution (i.e. the class distribution observed in the training set).
- **most_frequent** - The most frequent class is always predicted: the probability vector is 1 for that class and 0 for all other classes.
- **stratified** - The predictions are generated by sampling one-hot vectors from a multinomial distribution parametrized by the empirical class prior probabilities.
- **uniform** - Generates predictions uniformly at random from the list of unique classes observed in the training labels, i.e. each class has equal probability.
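To make the four strategies concrete, here is a minimal NumPy sketch of the probability vectors each one produces. It is an illustrative reimplementation modeled on scikit-learn's `DummyClassifier`, not deepchecks' internal code; the function name `simple_model_proba` and its arguments are hypothetical, and **uniform** is interpreted here as sampling one-hot predictions.

```python
import numpy as np


def simple_model_proba(strategy, y_train, n_samples, n_classes, seed=0):
    """Return an (n_samples, n_classes) matrix of baseline probability vectors.

    Hypothetical helper for illustration only; not part of deepchecks.
    """
    rng = np.random.default_rng(seed)
    y_train = np.asarray(y_train)
    # Empirical class prior: fraction of training labels belonging to each class
    prior = np.bincount(y_train, minlength=n_classes) / len(y_train)

    if strategy == "prior":
        # Every sample gets the same vector: the class prior itself
        return np.tile(prior, (n_samples, 1))
    if strategy == "most_frequent":
        # One-hot vector on the most frequent training class
        proba = np.zeros((n_samples, n_classes))
        proba[:, np.argmax(prior)] = 1.0
        return proba
    if strategy == "stratified":
        # One-hot vectors drawn from a multinomial parametrized by the prior
        return rng.multinomial(1, prior, size=n_samples).astype(float)
    if strategy == "uniform":
        # Each sample gets a class drawn uniformly from the classes seen in training
        seen = np.flatnonzero(prior > 0)
        picks = rng.choice(seen, size=n_samples)
        proba = np.zeros((n_samples, n_classes))
        proba[np.arange(n_samples), picks] = 1.0
        return proba
    raise ValueError(f"unknown strategy: {strategy}")


y_train = np.array([0, 0, 0, 1, 1, 2])
for s in ["prior", "most_frequent", "stratified", "uniform"]:
    proba = simple_model_proba(s, y_train, n_samples=4, n_classes=3)
    # prior and most_frequent both print [0 0 0 0]
    print(s, proba.argmax(axis=1))
```

Note that **prior** and **most_frequent** yield the same hard predictions (the argmax of the prior is the most frequent class); they differ only in the probability vectors, which matters for metrics computed on probabilities.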

## Generate data & model <a name="generate_data_model"></a>

```python
from deepchecks.vision.base import VisionData
from deepchecks.vision.checks.performance import SimpleModelComparison
```

```python
from deepchecks.vision.datasets.classification import mnist

mnist_model = mnist.load_model()
train_ds = mnist.load_dataset(train=True, object_type='VisionData')
test_ds = mnist.load_dataset(train=False, object_type='VisionData')
```

```python
from deepchecks.vision.utils.classification_formatters import ClassificationPredictionFormatter

# Wrap the MNIST prediction formatter so the check can read the model's outputs
pred_formatter = ClassificationPredictionFormatter(mnist.mnist_prediction_formatter)
```

## Run the check <a name="run_check"></a>

We will run the check with the **stratified** strategy. The check will use the default classification metrics: precision and recall. These can be overridden by providing an alternative scorer using the `alternative_metrics` parameter.

```python
check = SimpleModelComparison(strategy='stratified')
result = check.run(train_ds, test_ds, mnist_model, prediction_formatter=pred_formatter)
```

```python
result
```

### Observe the check's output

We can see in the results that the check calculates the score for each class in the dataset, and compares the scores between our model and the simple model.
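The per-class comparison can be illustrated with a small sketch that computes precision and recall for each class from hard label predictions, for both a model and a trivial baseline. This is plain NumPy on hypothetical toy labels, not the scorer deepchecks uses internally; the helper name `per_class_precision_recall` is hypothetical, and returning 0.0 when a class is never predicted (or never present) is an assumed zero-division convention.

```python
import numpy as np


def per_class_precision_recall(y_true, y_pred, n_classes):
    """Per-class precision and recall (hypothetical helper, not deepchecks code)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    precision = np.zeros(n_classes)
    recall = np.zeros(n_classes)
    for c in range(n_classes):
        tp = np.sum((y_pred == c) & (y_true == c))  # true positives for class c
        n_pred = np.sum(y_pred == c)  # samples predicted as class c
        n_true = np.sum(y_true == c)  # samples whose label is class c
        # Assumed convention: 0.0 when the denominator is zero
        precision[c] = tp / n_pred if n_pred else 0.0
        recall[c] = tp / n_true if n_true else 0.0
    return precision, recall


# Toy labels (hypothetical, for illustration only)
y_true = np.array([0, 0, 1, 1, 2, 2])
model_pred = np.array([0, 0, 1, 1, 2, 1])
baseline_pred = np.zeros_like(y_true)  # constant baseline that always predicts class 0

model_p, model_r = per_class_precision_recall(y_true, model_pred, 3)
base_p, base_r = per_class_precision_recall(y_true, baseline_pred, 3)

for c in range(3):
    print(f"class {c}: precision {model_p[c]:.2f} vs {base_p[c]:.2f}, "
          f"recall {model_r[c]:.2f} vs {base_r[c]:.2f}")
```

Here the constant baseline matches the model's recall on class 0 (both 1.0) while its precision is far lower, which is one reason to compare more than a single metric per class.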

In addition to the graphical output, the check also returns a value that includes all the information needed to define validation conditions.

The value is a dataframe that contains the metrics' values for each class and dataset:

```python
result.value.sort_values(by=['Class', 'Metric']).head(10)
```

| | Metric | Class | Class Name | Value | Dataset | Number of samples |
|---|---|---|---|---|---|---|
| 0 | Precision | 0 | 0 | 0.985801 | Test | 980 |
| 0 | Precision | 0 | 0 | 0.104146 | Simple Model | 980 |
| 10 | Recall | 0 | 0 | 0.991837 | Test | 980 |
| 10 | Recall | 0 | 0 | 0.103061 | Simple Model | 980 |
| 1 | Precision | 1 | 1 | 0.988616 | Test | 1135 |
| 1 | Precision | 1 | 1 | 0.112945 | Simple Model | 1135 |
| 11 | Recall | 1 | 1 | 0.994714 | Test | 1135 |
| 11 | Recall | 1 | 1 | 0.115419 | Simple Model | 1135 |
| 2 | Precision | 2 | 2 | 0.989268 | Test | 1032 |
| 2 | Precision | 2 | 2 | 0.118557 | Simple Model | 1032 |
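One way to turn this DataFrame into a quick pass/fail signal is to pivot it so that each (Class, Metric) pair shows the model's and the simple model's scores side by side, then flag pairs where the gain over the baseline is small. The sketch below uses pandas on rows copied from the output above; in the notebook you would use `result.value` in place of the hand-built `value` frame. The `MIN_GAIN` threshold is a hypothetical value chosen for illustration, not a deepchecks default, and this is not one of the check's built-in conditions.

```python
import pandas as pd

# Rows copied from the output above (same column names as result.value)
value = pd.DataFrame({
    "Metric": ["Precision", "Precision", "Recall", "Recall"] * 2,
    "Class": [0, 0, 0, 0, 1, 1, 1, 1],
    "Value": [0.985801, 0.104146, 0.991837, 0.103061,
              0.988616, 0.112945, 0.994714, 0.115419],
    "Dataset": ["Test", "Simple Model"] * 4,
})

# One row per (Class, Metric), one column per Dataset
wide = value.pivot_table(index=["Class", "Metric"], columns="Dataset", values="Value")
wide["Gain"] = wide["Test"] - wide["Simple Model"]

MIN_GAIN = 0.1  # hypothetical threshold, for illustration only
failing = wide[wide["Gain"] < MIN_GAIN]

print(wide.round(3))
print("Pairs below threshold:", len(failing))
```

For the MNIST model above, every class and metric beats the stratified baseline by a wide margin, so no pair is flagged.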
