📝 **Author:** Amirhossein Heydari - 📧 **Email:** amirhosseinheydari78@gmail.com - 📍 **Linktree:** [linktr.ee/mr_pylin](https://linktr.ee/mr_pylin)

---

# Dependencies

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import classification_report
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split
from torchinfo import summary
from torchmetrics import Accuracy, ConfusionMatrix
from torchvision.datasets import CIFAR10
from torchvision.transforms import v2

In [2]:
# reproducibility: one fixed seed for torch's RNG plus deterministic cuDNN kernels
random_state = 42
torch.manual_seed(random_state)

# disable cuDNN auto-tuning and request deterministic algorithms
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
# prefer the GPU whenever CUDA is usable, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

# Pre-Processing

# Load Dataset

In [4]:
# initial transforms: convert each image to a float32 image tensor with rescaled values
transforms = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])

# load both CIFAR-10 splits with the same settings (only the `train` flag differs)
dataset_kwargs = dict(root='./dataset', download=True, transform=transforms)
trainset = CIFAR10(train=True, **dataset_kwargs)
testset = CIFAR10(train=False, **dataset_kwargs)

# log the raw storage of each split
for split_name, split in (('trainset', trainset), ('testset', testset)):
    rows = (
        (f"{split_name}.data.shape", split.data.shape),
        (f"{split_name}.data.dtype", split.data.dtype),
        (f"type({split_name}.data)", type(split.data)),
        (f"type({split_name}.targets)", type(split.targets)),
    )
    print(f"{split_name}:")
    for label, value in rows:
        # labels are left-aligned to the widest one (22 characters)
        print(f"    -> {label:<22} : {value}")
    print('-' * 50)

# log the label mapping and the number of samples per class
print(f"classes: {trainset.class_to_idx}")
train_counts = np.unique(trainset.targets, return_counts=True)[1]
test_counts = np.unique(testset.targets, return_counts=True)[1]
print(f"trainset distribution: {train_counts}")
print(f"testset  distribution: {test_counts}")

Files already downloaded and verified
Files already downloaded and verified
trainset:
    -> trainset.data.shape    : (50000, 32, 32, 3)
    -> trainset.data.dtype    : uint8
    -> type(trainset.data)    : <class 'numpy.ndarray'>
    -> type(trainset.targets) : <class 'list'>
--------------------------------------------------
testset:
    -> testset.data.shape     : (10000, 32, 32, 3)
    -> testset.data.dtype     : uint8
    -> type(testset.data)     : <class 'numpy.ndarray'>
    -> type(testset.targets)  : <class 'list'>
--------------------------------------------------
classes: {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
trainset distribution: [5000 5000 5000 5000 5000 5000 5000 5000 5000 5000]
testset  distribution: [1000 1000 1000 1000 1000 1000 1000 1000 1000 1000]


In [None]:
# plot the first 32 training images (4 x 8 grid) titled with their class names
fig, axs = plt.subplots(nrows=4, ncols=8, figsize=(12, 6), layout='compressed')
for idx, ax in enumerate(axs.flat):
    # the images are RGB, so no colormap is passed (imshow ignores `cmap` for RGB data)
    ax.imshow(trainset.data[idx])
    ax.set_title(trainset.classes[trainset.targets[idx]])
    ax.axis('off')
plt.show()

## Split trainset into [trainset, validationset]

In [6]:
# hold out 10% of the training data as a validation set
trainset, validationset = random_split(trainset, [.9, .1])

# log the size and the first (image, label) pair of every split
splits = (('trainset', trainset), ('validationset', validationset), ('testset', testset))
for position, (split_name, split) in enumerate(splits):
    first_image, first_label = split[0]
    # every section except the last one is followed by an empty line
    trailer = '\n' if position < len(splits) - 1 else ''
    size_label = f"len({split_name})"
    image_label = f"{split_name}[0][0]"
    target_label = f"{split_name}[0][1]"
    print(f"{split_name}:")
    print(f"    -> {size_label:<19} : {len(split)}")
    print(f"    -> {image_label:<19} : {first_image.shape}")
    print(f"    -> {target_label:<19} : {first_label}{trailer}")

trainset:
    -> len(trainset)       : 45000
    -> trainset[0][0]      : torch.Size([3, 32, 32])
    -> trainset[0][1]      : 6

validationset:
    -> len(validationset)  : 5000
    -> validationset[0][0] : torch.Size([3, 32, 32])
    -> validationset[0][1] : 7

testset:
    -> len(testset)        : 10000
    -> testset[0][0]       : torch.Size([3, 32, 32])
    -> testset[0][1]       : 3


## Normalization

In [7]:
# load the entire training split as one single batch
stats_loader = DataLoader(trainset, batch_size=len(trainset))
all_train_images = next(iter(stats_loader))[0]

# per-channel statistics: reduce over the batch, height and width axes
reduce_dims = (0, 2, 3)
train_mean = all_train_images.mean(dim=reduce_dims)  # [0.4917, 0.4823, 0.4467]
train_std = all_train_images.std(dim=reduce_dims)  # [0.2471, 0.2435, 0.2616]

# the full-size copy of the training images is no longer needed
del stats_loader, all_train_images

# log
print(f"train mean per channel : {train_mean}")
print(f"train std  per channel : {train_std}")

train mean per channel : tensor([0.4917, 0.4823, 0.4467])
train std  per channel : tensor([0.2471, 0.2435, 0.2616])


## Transform
   - on-the-fly data augmentation
   - Disadvantage:
      - the same transforms are re-applied to the same data in every epoch [extra computation]
   - Advantage:
      - Reduced Memory Usage, Regularization & Data Diversity [random transforms e.g. RandomCrop]

In [8]:
# display the current transform pipeline (before the normalization step is appended)
transforms

Compose(
      ToImage()
      ToDtype(scale=True)
)

In [9]:
# extend the shared pipeline in place: the datasets were built with this very object,
# so all of them pick up the normalization step without being re-created
normalize = v2.Normalize(mean=train_mean, std=train_std)
transforms.transforms.append(normalize)

# log
print(f"trainset.dataset.transforms:\n{trainset.dataset.transforms}\n")
print(f"validationset.dataset.transforms:\n{validationset.dataset.transforms}\n")
print(f"testset.transforms:\n{testset.transforms}")

trainset.dataset.transforms:
StandardTransform
Transform: Compose(
                 ToImage()
                 ToDtype(scale=True)
                 Normalize(mean=[tensor(0.4917), tensor(0.4823), tensor(0.4467)], std=[tensor(0.2471), tensor(0.2435), tensor(0.2616)], inplace=False)
           )

validationset.dataset.transforms:
StandardTransform
Transform: Compose(
                 ToImage()
                 ToDtype(scale=True)
                 Normalize(mean=[tensor(0.4917), tensor(0.4823), tensor(0.4467)], std=[tensor(0.2471), tensor(0.2435), tensor(0.2616)], inplace=False)
           )

testset.transforms:
StandardTransform
Transform: Compose(
                 ToImage()
                 ToDtype(scale=True)
                 Normalize(mean=[tensor(0.4917), tensor(0.4823), tensor(0.4467)], std=[tensor(0.2471), tensor(0.2435), tensor(0.2616)], inplace=False)
           )


In [10]:
# compare the stored raw sample with the sample returned by the dataset (transform applied)
raw_sample = testset.data[0]
transformed_sample = testset[0][0]

print("before applying transform:")
print(f"    -> type(testset.data[0]) : {type(raw_sample)}")
print(f"    -> testset.data[0].dtype : {raw_sample.dtype}")
print(f"    -> testset.data[0].shape : {raw_sample.shape}")
print('-' * 50)
print("after applying transform:")
print(f"    -> type(testset[0][0])   : {type(transformed_sample)}")
print(f"    -> testset[0][0].dtype   : {transformed_sample.dtype}")
print(f"    -> testset[0][0].shape   : {transformed_sample.shape}")

before applying transform:
    -> type(testset.data[0]) : <class 'numpy.ndarray'>
    -> testset.data[0].dtype : uint8
    -> testset.data[0].shape : (32, 32, 3)
--------------------------------------------------
after applying transform:
    -> type(testset[0][0])   : <class 'torchvision.tv_tensors._image.Image'>
    -> testset[0][0].dtype   : torch.float32
    -> testset[0][0].shape   : torch.Size([3, 32, 32])


## DataLoader

In [11]:
batch_size = 64

# settings shared by all loaders; only the training data is shuffled
loader_kwargs = dict(batch_size=batch_size, num_workers=2)
trainloader = DataLoader(dataset=trainset, shuffle=True, **loader_kwargs)
validationloader = DataLoader(dataset=validationset, shuffle=False, **loader_kwargs)
testloader = DataLoader(dataset=testset, shuffle=False, **loader_kwargs)

In [12]:
# fetch one batch from every loader to inspect its shapes and dtypes
first_train_batch = next(iter(trainloader))
first_validation_batch = next(iter(validationloader))
first_test_batch = next(iter(testloader))

print(f"trainloader      first batch     -> x.shape: {first_train_batch[0].shape} - y.shape: {first_train_batch[1].shape} - x.dtype: {first_train_batch[0].dtype} - y.dtype: {first_train_batch[1].dtype}")
print(f"validationloader first batch     -> x.shape: {first_validation_batch[0].shape} - y.shape: {first_validation_batch[1].shape} - x.dtype: {first_validation_batch[0].dtype} - y.dtype: {first_validation_batch[1].dtype}")
print(f"testloader       first batch     -> x.shape: {first_test_batch[0].shape} - y.shape: {first_test_batch[1].shape} - x.dtype: {first_test_batch[0].dtype} - y.dtype: {first_test_batch[1].dtype}")

# size of the final batch: a remainder of 0 means the dataset splits evenly,
# in which case the last batch is a full one rather than an empty one
print(f"trainloader      last batch-size -> {len(trainset) % batch_size or batch_size}")
print(f"validationloader last batch-size -> {len(validationset) % batch_size or batch_size}")
print(f"testloader       last batch-size -> {len(testset) % batch_size or batch_size}")

trainloader      first batch     -> x.shape: torch.Size([64, 3, 32, 32]) - y.shape: torch.Size([64]) - x.dtype: torch.float32 - y.dtype: torch.int64
validationloader first batch     -> x.shape: torch.Size([64, 3, 32, 32]) - y.shape: torch.Size([64]) - x.dtype: torch.float32 - y.dtype: torch.int64
testloader       first batch     -> x.shape: torch.Size([64, 3, 32, 32]) - y.shape: torch.Size([64]) - x.dtype: torch.float32 - y.dtype: torch.int64
trainloader      last batch-size -> 8
validationloader last batch-size -> 8
testloader       last batch-size -> 16


# Network Structure: Convolutional Neural Networks
   - Sequential Model
      - Use torch.nn.Sequential to create a sequence of layers or modules
      - [pytorch.org/docs/stable/generated/torch.nn.Sequential.html](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html)
   - Functional Model
      - for stateless operations like activation functions, loss functions, and other operations within the forward method of custom modules or in custom functions
      - [pytorch.org/docs/stable/nn.functional.html](https://pytorch.org/docs/stable/nn.functional.html)
   - Mixed Model

**Notes**:
   - `torch.nn.Conv2d`
      - loss function : 
         - multi-class classification : `torch.nn.CrossEntropyLoss` = `torch.nn.LogSoftmax` + `torch.nn.NLLLoss`
         - [pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html)
         - [pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html](https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html)
      - activation function for the last layer:
         - when using `torch.nn.CrossEntropyLoss` as a loss function, the output layer doesn't need an activation function
         - `torch.nn.CrossEntropyLoss` calculates `torch.nn.LogSoftmax` and `torch.nn.NLLLoss` internally.
         - [pytorch.org/docs/stable/generated/torch.nn.Softmax.html](https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html)
         - [pytorch.org/docs/stable/generated/torch.nn.LogSoftmax.html](https://pytorch.org/docs/stable/generated/torch.nn.LogSoftmax.html)
      - Weights
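         - Initialized based on a scheme similar to Kaiming/He initialization
         - PyTorch default [`kaiming_uniform_` with $a = \sqrt{5}$]: $W \sim \mathcal{U}\left(-\frac{1}{\sqrt{n_{\text{in}}}}, \frac{1}{\sqrt{n_{\text{in}}}}\right)$
         - Kaiming/He Uniform Distribution [ReLU gain]: $W \sim \mathcal{U}\left(-\sqrt{\frac{6}{n_{\text{in}}}}, \sqrt{\frac{6}{n_{\text{in}}}}\right)$
         - Kaiming/He Normal Distribution: $W \sim \mathcal{N}\left(0, \frac{2}{n_{\text{in}}}\right)$
      - Biases:
         - Initialized by default from $\mathcal{U}\left(-\frac{1}{\sqrt{n_{\text{in}}}}, \frac{1}{\sqrt{n_{\text{in}}}}\right)$ (not to zero)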
      - [pytorch.org/docs/stable/nn.init.html](https://pytorch.org/docs/stable/nn.init.html)
      - Paper: [Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification - He, K. et al. (2015).](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf)

**Playground**:
   - [poloclub.github.io/cnn-explainer](https://poloclub.github.io/cnn-explainer/)
   - [convnetplayground.fastforwardlabs.com](https://convnetplayground.fastforwardlabs.com/)
   - [alexlenail.me/NN-SVG](https://alexlenail.me/NN-SVG/)

<figure style="text-align: center;">
    <img src="../../assets/images/original/convolutional-neural-network.svg" alt="convolutional-neural-network.svg" style="width: 100%;">
    <figcaption>Convolutional Neural Networks Model</figcaption>
</figure>

<table style="margin-left:auto;margin-right:auto;text-align:center;">
  <thead>
    <tr>
      <th colspan="4" style="text-align:center;">Feature Extraction</th>
      <th colspan="4" style="text-align:center;">Classification</th>
    </tr>
    <tr>
      <th colspan="2">Convolution_1 parameters</th>
      <th colspan="2">Convolution_2 parameters</th>
      <th colspan="2">hidden<sub>1</sub> parameters</th>
      <th colspan="2">logits parameters</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Weights</td>
      <td>Biases</td>
      <td>Weights</td>
      <td>Biases</td>
      <td>Weights</td>
      <td>Biases</td>
      <td>Weights</td>
      <td>Biases</td>
    </tr>
    <tr>
      <td>(1 × 3 × 3) × A</td>
      <td>A</td>
      <td>(A × 3 × 3) × B</td>
      <td>B</td>
      <td>C × D</td>
      <td>D</td>
      <td>D × E</td>
      <td>E</td>
    </tr>
  </tbody>
  <tfoot>
    <tr>
      <td colspan="2">(1 × 3 × 3 + 1) × A</td>
      <td colspan="2">(A × 3 × 3 + 1) × B</td>
      <td colspan="2">(C + 1) × D</td>
      <td colspan="2">(D + 1) × E</td>
    </tr>
  </tfoot>
</table>


In [13]:
class CIFAR10Model(nn.Module):
    """
    Small CNN classifier: two convolutional blocks, global average pooling and a linear head.

    The forward pass returns raw logits (no activation on the output layer).
    """

    def __init__(self, in_channels, output_dim):
        """
        Args:
            in_channels: number of channels of the input images.
            output_dim: number of output logits (one per class).
        """
        super().__init__()

        # shape comments below assume a 3x32x32 input
        layers = [
            *self._conv_block(in_channels, 32),  # -> 32x15x15
            *self._conv_block(32, 64),  # -> 64x6x6
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),  # -> 64x1x1
        ]
        self.feature_extractor = nn.Sequential(*layers)
        self.flatten = nn.Flatten(start_dim=1)
        self.classifier = nn.Sequential(nn.Linear(64, output_dim))

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Return the layers of one block: 3x3 conv -> batch norm (per channel) -> ReLU -> 2x2 max-pool."""
        return [
            nn.Conv2d(in_channels, out_channels=out_channels, kernel_size=3),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]

    def forward(self, x):
        """Extract features, flatten them per sample and map them to class logits."""
        features = self.feature_extractor(x)
        return self.classifier(self.flatten(features))

In [14]:
# derive the model dimensions from the data: channels of one sample and the number of classes
sample_image = trainset[0][0]
in_channels = sample_image.shape[0]
output_dim = len(trainset.dataset.classes)

# build the model on the selected device and display its structure
model = CIFAR10Model(in_channels, output_dim).to(device)
model

CIFAR10Model(
  (feature_extractor): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): AdaptiveAvgPool2d(output_size=(1, 1))
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (classifier): Sequential(
    (0): Linear(in_features=64, out_features=10, bias=True)
  )
)

In [15]:
# the raw images are stored as (N, H, W, C); the model is summarized for (batch, C, H, W) inputs
num_samples, height, width, channels = testset.data.shape
summary(model, input_size=(batch_size, channels, height, width))

Layer (type:depth-idx)                   Output Shape              Param #
CIFAR10Model                             [64, 10]                  --
├─Sequential: 1-1                        [64, 64, 1, 1]            --
│    └─Conv2d: 2-1                       [64, 32, 30, 30]          896
│    └─BatchNorm2d: 2-2                  [64, 32, 30, 30]          64
│    └─ReLU: 2-3                         [64, 32, 30, 30]          --
│    └─MaxPool2d: 2-4                    [64, 32, 15, 15]          --
│    └─Conv2d: 2-5                       [64, 64, 13, 13]          18,496
│    └─BatchNorm2d: 2-6                  [64, 64, 13, 13]          128
│    └─ReLU: 2-7                         [64, 64, 13, 13]          --
│    └─MaxPool2d: 2-8                    [64, 64, 6, 6]            --
│    └─AdaptiveAvgPool2d: 2-9            [64, 64, 1, 1]            --
├─Flatten: 1-2                           [64, 64]                  --
├─Sequential: 1-3                        [64, 10]                  --
│    └─Li

# Set up remaining Hyper-Parameters

In [16]:
# training schedule
num_epochs = 10
lr = 0.01

# loss on raw logits and the optimizer over all model parameters
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=lr)

# Train & Validation Loop

In [17]:
# per-epoch history of the training and validation metrics
train_acc_per_epoch, train_loss_per_epoch = [], []
val_acc_per_epoch, val_loss_per_epoch = [], []

In [18]:
# independent top-1 accuracy trackers for the training and the validation phase
num_classes = len(testset.classes)
train_acc = Accuracy(task='multiclass', num_classes=num_classes, top_k=1).to(device)
val_acc = Accuracy(task='multiclass', num_classes=num_classes, top_k=1).to(device)

In [19]:
for epoch in range(num_epochs):

    # train loop
    model.train()
    train_loss = 0

    for x, y in trainloader:

        # send data to the selected device
        x, y_true = x.to(device), y.to(device)

        # forward
        y_pred = model(x)
        loss = criterion(y_pred, y_true)

        # backward
        loss.backward()

        # update parameters, then clear the gradients for the next batch
        optimizer.step()
        optimizer.zero_grad()

        # log loss & accuracy (the batch loss is weighted by the batch size,
        # so dividing by the dataset size below yields a per-sample average)
        train_loss += loss.item() * len(x)
        train_acc.update(y_pred, y_true)

    # store intermediate loss & accuracy
    train_loss_per_epoch.append(train_loss / len(trainset))
    train_acc_per_epoch.append(train_acc.compute().item())
    train_acc.reset()

    # validation loop
    model.eval()
    val_loss = 0

    with torch.no_grad():
        for x, y in validationloader:

            # send data to the selected device
            x, y_true = x.to(device), y.to(device)

            # forward
            y_pred = model(x)
            loss = criterion(y_pred, y_true)

            # log loss & accuracy
            val_loss += loss.item() * len(x)
            val_acc.update(y_pred, y_true)

    # store intermediate loss & accuracy
    val_loss_per_epoch.append(val_loss / len(validationset))
    val_acc_per_epoch.append(val_acc.compute().item())
    val_acc.reset()

    # log the values just stored; indexing with [-1] instead of [epoch] stays correct
    # even when this cell is re-run and the history lists already hold earlier epochs
    print(f"epoch {epoch:>1}  ->  train[loss: {train_loss_per_epoch[-1]:.5f} - acc: {train_acc_per_epoch[-1]:.2f}] | validation[loss: {val_loss_per_epoch[-1]:.5f} - acc: {val_acc_per_epoch[-1]:.2f}]")

epoch 0  ->  train[loss: 1.53168 - acc: 0.44] | validation[loss: 1.49292 - acc: 0.48]
epoch 1  ->  train[loss: 1.28642 - acc: 0.54] | validation[loss: 1.23919 - acc: 0.57]
epoch 2  ->  train[loss: 1.20031 - acc: 0.58] | validation[loss: 1.59825 - acc: 0.46]
epoch 3  ->  train[loss: 1.14244 - acc: 0.60] | validation[loss: 1.16952 - acc: 0.60]
epoch 4  ->  train[loss: 1.08944 - acc: 0.62] | validation[loss: 1.22298 - acc: 0.57]
epoch 5  ->  train[loss: 1.04375 - acc: 0.64] | validation[loss: 1.24219 - acc: 0.58]
epoch 6  ->  train[loss: 1.01026 - acc: 0.65] | validation[loss: 1.10631 - acc: 0.61]
epoch 7  ->  train[loss: 0.98565 - acc: 0.66] | validation[loss: 1.02241 - acc: 0.64]
epoch 8  ->  train[loss: 0.95677 - acc: 0.67] | validation[loss: 1.19718 - acc: 0.62]
epoch 9  ->  train[loss: 0.93040 - acc: 0.68] | validation[loss: 1.10159 - acc: 0.62]


## Model Analysis

In [None]:
# plot the training vs. validation curves: loss on the left, accuracy on the right
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 4), layout='compressed')
curves = (
    ("Loss", train_loss_per_epoch, val_loss_per_epoch),
    ("Accuracy", train_acc_per_epoch, val_acc_per_epoch),
)
for ax, (metric_name, train_values, val_values) in zip(axs, curves):
    ax.plot(train_values, label=f"Train {metric_name.lower()}")
    ax.plot(val_values, label=f"Validation {metric_name.lower()}")
    ax.set(title=f"{metric_name} over time", xlabel='Epoch', ylabel=metric_name)
    ax.legend(loc='best', fancybox=True, shadow=True)
plt.show()

# Test Loop

In [21]:
# top-1 accuracy tracker for the test phase, placed on the same device as the model outputs
test_acc = Accuracy(task='multiclass', num_classes=len(testset.classes), top_k=1)
test_acc = test_acc.to(device)

In [22]:
model.eval()

# start from a clean metric state so re-running this cell does not accumulate old batches
test_acc.reset()
test_loss = 0
predictions = []
targets = []

with torch.no_grad():
    for x, y in testloader:

        # send data to the selected device
        x, y_true = x.to(device), y.to(device)

        # forward
        y_pred = model(x)
        loss = criterion(y_pred, y_true)

        # log loss & accuracy
        test_loss += loss.item() * len(x)
        test_acc.update(y_pred, y_true)

        # collect plain Python ints (not 0-dim tensors) for the reports computed later
        predictions.extend(y_pred.argmax(dim=1).tolist())
        targets.extend(y_true.tolist())

# log
print(f"test[loss: {test_loss / len(testset):.5f} - acc: {test_acc.compute().item():.2f}]")

test[loss: 1.07869 - acc: 0.62]


## Metrics
   - Loss
   - Accuracy
   - Recall
   - Precision
   - F1-Score
   - Confusion Matrix
   - Area Under the ROC Curve (AUC-ROC)
   - Area Under the Precision-Recall Curve (AUC-PR)
   - ...

**Docs**:
   - [lightning.ai/docs/torchmetrics/stable/all-metrics.html](https://lightning.ai/docs/torchmetrics/stable/all-metrics.html)
   - [scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html)

In [23]:
# per-class precision / recall / f1-score over the collected test predictions
report = classification_report(targets, predictions)
print(report)

              precision    recall  f1-score   support

           0       0.76      0.59      0.66      1000
           1       0.85      0.71      0.77      1000
           2       0.61      0.45      0.52      1000
           3       0.52      0.39      0.44      1000
           4       0.36      0.87      0.51      1000
           5       0.76      0.38      0.51      1000
           6       0.77      0.60      0.67      1000
           7       0.62      0.66      0.64      1000
           8       0.79      0.76      0.78      1000
           9       0.70      0.82      0.75      1000

    accuracy                           0.62     10000
   macro avg       0.67      0.62      0.63     10000
weighted avg       0.67      0.62      0.63     10000



In [28]:
# confusion matrix (rows: true classes, columns: predicted classes);
# the number of classes comes from the dataset, like the accuracy metrics, instead of a hard-coded 10
metric = ConfusionMatrix(task='multiclass', num_classes=len(testset.classes))
confusion_matrix = metric(torch.tensor(predictions), torch.tensor(targets))

# log
print(confusion_matrix)

# plot
fig, ax = plt.subplots(figsize=(8, 8))
metric.plot(ax=ax)
plt.show()

tensor([[587,  31,  63,  16,  85,   4,  10,  41, 109,  54],
        [ 18, 710,  14,   8,  37,   1,   4,  15,  22, 171],
        [ 49,   3, 453,  55, 270,  17,  60,  72,  12,   9],
        [ 11,   3,  57, 388, 306,  70,  57,  68,  11,  29],
        [  8,   2,  25,  24, 870,   5,  26,  33,   3,   4],
        [  8,   5,  42, 169, 239, 383,  17, 119,  10,   8],
        [  3,   1,  53,  37, 274,   5, 601,  16,   5,   5],
        [  9,   1,  21,  26, 253,  20,   2, 656,   1,  11],
        [ 65,  31,   9,  15,  46,   0,   3,  17, 759,  55],
        [ 11,  46,   7,  11,  56,   1,   3,  25,  24, 816]])


# Prediction

In [25]:
def predict(model: nn.Module, data: np.ndarray, classes: list, transform: "v2.Compose | None" = None) -> np.ndarray:
    """
    Predict the class label of a single sample or of a batch of samples.

    Args:
        model: classifier that returns one score per class for every sample.
        data: one sample (3 dimensions) or a batch of samples (4 dimensions).
        classes: class names, indexed by class id.
        transform: optional transform applied to every sample individually.
            When omitted, `data` is assumed to be already preprocessed and is
            only converted to a tensor.

    Returns:
        The predicted class name(s) looked up in `classes`, as a NumPy value
        (not a tensor).

    Note:
        The model is switched to evaluation mode, and the input is moved to the
        device of the model's parameters (CPU for a parameter-free model).
    """

    # add batch dimension to a single data
    if data.ndim == 3:
        data = data[None]

    # apply the transform sample by sample, or take preprocessed data as-is
    if transform is not None:
        batch = torch.stack([transform(sample) for sample in data])
    else:
        batch = torch.as_tensor(data)

    # run on the model's own device instead of relying on a global variable
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    # predict
    model.eval()
    with torch.no_grad():
        y_pred = model(batch.to(target_device)).argmax(dim=1).cpu()

    # idx to labels
    return np.array(classes)[y_pred]

In [26]:
# some raw data: the first 32 test images exactly as stored (no transform applied)
raw_testset = CIFAR10(root='./dataset', train=False, download=True, transform=None)
raw_data = raw_testset.data[:32]

# predict (the preprocessing pipeline is applied inside `predict`)
y_pred = predict(model, raw_data, testset.classes, transforms)

# log
print(f"predictions:\n{y_pred}")

Files already downloaded and verified
predictions:
['cat' 'ship' 'ship' 'airplane' 'frog' 'frog' 'automobile' 'deer' 'deer'
 'automobile' 'ship' 'truck' 'cat' 'horse' 'truck' 'ship' 'dog' 'horse'
 'truck' 'frog' 'horse' 'airplane' 'deer' 'truck' 'deer' 'deer' 'cat'
 'bird' 'truck' 'deer' 'frog' 'dog']


In [None]:
# predict the whole batch once instead of running the model separately for every image
titles = predict(model, raw_data, testset.classes, transform=transforms)

# plot every raw image (4 x 8 grid) titled with its predicted class
fig, axs = plt.subplots(nrows=4, ncols=8, figsize=(12, 6), layout='compressed')
for ax, image, title in zip(axs.flat, raw_data, titles):
    # the images are RGB, so no colormap is passed (imshow ignores `cmap` for RGB data)
    ax.imshow(image)
    ax.set_title(title)
    ax.axis('off')
plt.show()