# Confusion Matrix

In [1]:
# Enable IPython's autoreload extension so that edits to imported modules are
# picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Imports

In [2]:
from deepchecks.vision.dataset import VisionData
from torchvision import models
import torchvision
import torch
from torch import nn
from torchvision.transforms import ToTensor
import copy

In [3]:
from deepchecks.vision.checks.performance import ConfusionMatrixReport

In [38]:

from deepchecks.vision.datasets.classification import mnist
def simple_nn():
    """Build a freshly initialised fully-connected classifier for 28x28 inputs.

    The global torch RNG is seeded with 42 before the network is instantiated,
    so repeated calls yield identical initial weights.

    Returns:
        nn.Module: a ``NeuralNetwork`` instance moved to the CPU. Its forward
        pass flattens the input and returns 10 raw logits per sample.
    """
    class NeuralNetwork(nn.Module):
        """Flatten -> Linear(784, 512) -> ReLU -> Linear(512, 512) -> ReLU -> Linear(512, 10)."""

        def __init__(self):
            super().__init__()
            self.flatten = nn.Flatten()
            hidden = 512
            layers = [
                nn.Linear(28 * 28, hidden),
                nn.ReLU(),
                nn.Linear(hidden, hidden),
                nn.ReLU(),
                nn.Linear(hidden, 10),
            ]
            self.linear_relu_stack = nn.Sequential(*layers)

        def forward(self, x):
            # Collapse every non-batch dimension, then run the dense stack.
            return self.linear_relu_stack(self.flatten(x))

    # Seed right before construction so weight initialisation is reproducible.
    torch.manual_seed(42)
    return NeuralNetwork().to('cpu')

In [39]:
def trained_mnist(simple_nn, mnist_data_loader_train):
    """Train a copy of ``simple_nn`` for a single pass over the given loader.

    Args:
        simple_nn: the network to train. It is deep-copied first, so the
            object passed in is not modified.
        mnist_data_loader_train: iterable yielding ``(inputs, targets)``
            batches and exposing a ``dataset`` attribute that supports ``len``.

    Returns:
        The trained copy, left in training mode.

    Uses cross-entropy loss with SGD (lr=1e-3) on the CPU, seeds the global
    torch RNG with 42, and prints the loss every 100 batches.
    """
    torch.manual_seed(42)
    # Work on a private copy so the caller's network keeps its weights.
    net = copy.deepcopy(simple_nn)
    criterion = nn.CrossEntropyLoss()
    sgd = torch.optim.SGD(net.parameters(), lr=1e-3)
    size = len(mnist_data_loader_train.dataset)

    # Single epoch.
    net.train()
    for step, (inputs, targets) in enumerate(mnist_data_loader_train):
        inputs = inputs.to('cpu')
        targets = targets.to('cpu')

        # Forward pass and loss.
        batch_loss = criterion(net(inputs), targets)

        # Backward pass and parameter update.
        sgd.zero_grad()
        batch_loss.backward()
        sgd.step()

        # Report progress only on every 100th batch.
        if step % 100 != 0:
            continue
        loss = batch_loss.item()
        current = step * len(inputs)
        print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

    return net


In [40]:
# Load the MNIST train and test splits, requesting VisionData objects.
train_ds = mnist.load_dataset(train=True, object_type='VisionData')
test_ds = mnist.load_dataset(train=False, object_type='VisionData')
# Train a fresh simple_nn() for one pass over the training split's data loader.
model = trained_mnist(simple_nn(), train_ds.get_data_loader())


loss: 2.287713  [    0/60000]
loss: 2.271768  [ 6400/60000]
loss: 2.227654  [12800/60000]
loss: 2.213990  [19200/60000]
loss: 2.159032  [25600/60000]
loss: 2.118144  [32000/60000]
loss: 2.020085  [38400/60000]
loss: 1.975659  [44800/60000]
loss: 1.823154  [51200/60000]
loss: 1.817346  [57600/60000]


In [46]:
from deepchecks.vision.utils.classification_formatters import ClassificationPredictionFormatter
# Confusion-matrix check with its default settings.
check = ConfusionMatrixReport()
# Run it on the MNIST training split, passing a ClassificationPredictionFormatter
# built from mnist.mnist_prediction_formatter.
# NOTE(review): `test_ds` is loaded earlier but never used — confirm that
# evaluating on the training split is intended.
check.run(train_ds, model, prediction_formatter=ClassificationPredictionFormatter(mnist.mnist_prediction_formatter))

# Confusion Matrix Report: Object Detection

In [31]:
from deepchecks.vision.datasets.detection import coco
import numpy as np

In [32]:
# Load the detection model from deepchecks' COCO helpers, asking for pretrained
# weights (presumably a YOLO model, going by the variable name — TODO confirm).
yolo = coco.load_model(pretrained=True)


In [33]:
# Load the COCO data with default arguments; the result is passed to VisionData
# as its data loader in a later cell.
coco_dataloader = coco.load_dataset()

In [34]:
from deepchecks.vision.dataset import VisionData

In [35]:
from deepchecks.vision.utils.detection_formatters import DetectionLabelFormatter, DetectionPredictionFormatter
# Wrap the COCO loader in a VisionData, using the YOLO label formatter and
# declaring 80 classes.
# NOTE(review): this rebinds `train_ds`, which previously held the MNIST
# training split; re-running the MNIST check cell after this one would use COCO data.
train_ds = VisionData(coco_dataloader, label_transformer=DetectionLabelFormatter(coco.yolo_label_formatter), num_classes=80)

In [37]:
# Confusion-matrix check for the detection task, with categories_to_display=100.
check = ConfusionMatrixReport(categories_to_display=100)
# Run it on the COCO VisionData with the loaded model, formatting predictions
# through coco.yolo_prediction_formatter.
check.run(
    train_ds,
    yolo,
    prediction_formatter=DetectionPredictionFormatter(coco.yolo_prediction_formatter),
)

[tensor([[4.50000e+01, 1.08000e+00, 1.87690e+02, 6.11590e+02, 2.85840e+02],
        [4.50000e+01, 3.11730e+02, 4.31016e+00, 3.19280e+02, 2.28680e+02],
        [5.00000e+01, 2.49600e+02, 2.29270e+02, 3.16240e+02, 2.45080e+02],
        [4.50000e+01, 3.05176e-04, 1.35101e+01, 4.34480e+02, 3.75120e+02],
        [4.90000e+01, 3.76200e+02, 4.03601e+01, 7.55501e+01, 4.65298e+01],
        [4.90000e+01, 4.65780e+02, 3.89700e+01, 5.80698e+01, 4.66699e+01],
        [4.90000e+01, 3.85700e+02, 7.36598e+01, 8.40198e+01, 7.05101e+01],
        [4.90000e+01, 3.64050e+02, 2.49024e+00, 9.47603e+01, 7.10698e+01]], dtype=torch.float64), tensor([[ 23.00000, 385.52990,  60.03001, 214.97025, 297.16013],
        [ 23.00000,  53.01024, 356.49000, 132.03008,  55.19001]], dtype=torch.float64), tensor([[ 58.00000, 204.86014,  31.01973, 254.88001, 324.12012],
        [ 75.00000, 237.56031, 155.80998, 166.39999, 195.25017]], dtype=torch.float64), tensor([[ 22.00000,   0.95999,  20.06001, 441.23009, 379.15014]], dtyp


User provided device_type of 'cuda', but CUDA is not available. Disabling



[tensor([[3.18191e+02, 0.00000e+00, 3.13718e+02, 2.45932e+02, 7.22805e-01, 4.50000e+01],
        [2.52497e+02, 2.25776e+02, 3.16938e+02, 2.45335e+02, 4.50116e-01, 5.00000e+01],
        [3.24759e+02, 4.08760e+02, 1.00446e+02, 7.01823e+01, 4.25880e-01, 5.00000e+01],
        [1.89853e+00, 2.26182e+01, 6.35830e+02, 4.53694e+02, 3.25460e-01, 6.00000e+01]]), tensor([[365.84082,  65.22775, 235.13348, 295.86938,   0.79881,  23.00000],
        [ 49.93516, 349.19617, 127.79150,  68.35443,   0.61604,  23.00000]]), tensor([[241.90906, 160.25955, 166.37646, 189.53174,   0.76393,  75.00000],
        [203.40369,  37.17995, 258.65381, 307.47241,   0.46351,  58.00000]]), tensor([[  2.97177,  17.65305, 437.78311, 381.26868,   0.84289,  22.00000]]), tensor([[176.28598, 148.78886, 243.00616, 491.21112,   0.84865,   0.00000],
        [  0.00000,  41.53351, 448.61981, 409.94836,   0.64747,  25.00000]]), tensor([[2.13479e+02, 4.47612e+01, 3.79123e+02, 2.91213e+02, 2.86351e-01, 1.60000e+01]]), tensor([[1.9945

[tensor([[ 22.00000,  11.97991, 315.59006, 349.08018, 324.40994],
        [ 22.00000,  40.45996, 192.98016, 273.61989, 139.16992],
        [ 22.00000, 239.46011,  93.63008,  99.18997,  63.72992]], dtype=torch.float64), tensor([[ 39.00000, 389.84033, 183.91992,   8.01984,  26.86992],
        [ 39.00000, 374.40001, 189.63983,   6.86976,  20.92992],
        [ 39.00000, 367.79007, 184.72008,   7.05024,  27.54000],
        [ 39.00000, 383.15967, 187.56000,   7.21024,  24.24000],
        [ 39.00000, 433.78975, 280.35985,  18.32000,  31.36992],
        [ 72.00000, 428.69982, 171.06983,  96.72000, 138.03984],
        [ 39.00000, 463.10015, 282.68017,   9.92000,  36.55008],
        [ 39.00000, 399.87008, 161.68008,   9.40032,  17.10000],
        [ 39.00000, 373.60994, 135.79992,  10.49024,  15.15984],
        [ 39.00000, 404.56031, 132.61992,   9.77984,  15.75024],
        [ 39.00000, 415.11010, 131.22024,  10.65024,  16.86000],
        [ 56.00000, 159.04000, 409.15993, 151.36000,  70.83984],
 

[tensor([[ 18.99690, 315.96600, 338.59674, 319.03772,   0.91007,  22.00000],
        [237.51813,  95.12727, 100.65485,  61.35946,   0.90833,  22.00000],
        [ 34.78110, 190.58105, 280.92059, 144.51074,   0.89944,  22.00000]]), tensor([[1.57956e+02, 4.08457e+02, 1.49840e+02, 6.98147e+01, 8.86408e-01, 5.60000e+01],
        [3.75861e+02, 2.23736e+02, 5.41929e+01, 3.50314e+01, 7.08006e-01, 6.80000e+01],
        [4.31403e+02, 1.68395e+02, 9.17708e+01, 1.35150e+02, 5.53419e-01, 7.20000e+01],
        [4.63243e+02, 2.81851e+02, 1.01569e+01, 3.53643e+01, 4.65606e-01, 3.90000e+01],
        [3.89901e+02, 1.84314e+02, 9.39490e+00, 2.61179e+01, 4.64250e-01, 3.90000e+01],
        [7.19516e+01, 1.91774e+02, 2.10062e+01, 4.03702e+01, 4.30912e-01, 4.00000e+01],
        [1.46526e+02, 1.90632e+02, 1.84500e+01, 3.30047e+01, 4.12285e-01, 4.00000e+01],
        [8.75708e+01, 1.91687e+02, 2.05163e+01, 3.89589e+01, 3.93090e-01, 4.00000e+01],
        [6.22750e+01, 1.89819e+02, 1.59653e+01, 4.27002e+01, 3.86