In [None]:
# IPython magic: render matplotlib figures inline in the notebook output
%matplotlib inline

In [84]:
from pathlib import Path
from typing import Final

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as nnf
from torchvision import datasets, transforms

In [71]:
# Directory that holds the CIFAR-10 files (fetched there if not already present)
data_path = Path("../cifar_data")

# CIFAR-10 class names; a name's position in the list is its integer label
class_names: list[str] = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]

# Per-channel normalization statistics taken from
# https://stackoverflow.com/a/69750247
cifar10_preprocessor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.4915, 0.4823, 0.4468),
        std=(0.2470, 0.2435, 0.2616),
    ),
])

# Training split, preprocessed with the transform above
cifar10_train = datasets.CIFAR10(
    root=data_path,
    train=True,
    transform=cifar10_preprocessor,
    download=True,
)

# Held-out split, preprocessed identically to the training split
cifar10_val = datasets.CIFAR10(
    root=data_path,
    train=False,
    transform=cifar10_preprocessor,
    download=True,
)


Files already downloaded and verified
Files already downloaded and verified


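Since I got the normalization values from the Internet, I should verify that they are accurate. If they are, the normalized training images should have a per-channel mean of approximately 0 and a standard deviation of approximately 1. To check this, I will stack the normalized images into a single NumPy array of shape (batch, 3, 32, 32) and take the mean and std over the batch axis and the two 32x32 pixel axes, leaving one value per channel.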

In [36]:
# Stack every normalized training image into one (batch, 3, 32, 32) array.
# Each dataset access re-runs the preprocessing transform, so every image is
# fetched exactly once instead of once per channel.
imgs = np.stack([img.numpy() for img, _ in cifar10_train])

print(imgs.shape)

(50000, 3, 32, 32)


In [39]:
# Per-channel mean: reduce over the batch axis (0) and both pixel axes (2, 3)
train_mean = imgs.mean(axis=(0, 2, 3))
print(train_mean)

[-0.00040607 -0.0005815  -0.00102856]


In [40]:
# Per-channel standard deviation: reduce over the batch axis (0) and both pixel axes (2, 3)
train_std = imgs.std(axis=(0, 2, 3))
print(train_std)


[1.0001289  0.9999368  0.99995327]


Great! After normalization, each channel has a mean of approximately zero and a standard deviation of approximately one, so the statistics taken from the Internet are accurate.

In [72]:
# Classes kept for the two-class problem, in the order of their new labels
cifar2_class_names: list[str] = ['airplane', 'bird']

# Original CIFAR-10 label of each kept class, in the same order
cifar10_to_2_indices: list[int] = [
    class_names.index(name) for name in cifar2_class_names
]

# Original CIFAR-10 label -> contiguous new label (0, 1, ...).
# Derived from the indices above so it cannot drift out of sync with
# `cifar2_class_names`; for ['airplane', 'bird'] this is {0: 0, 2: 1}.
label_map: dict[int, int] = {
    cifar10_label: cifar2_label
    for cifar2_label, cifar10_label in enumerate(cifar10_to_2_indices)
}

# Training samples of the kept classes, relabelled with the new labels
cifar2 = [
    (img, label_map[label])
    for img, label in cifar10_train
    if label in label_map
]

# Validation samples of the kept classes, relabelled the same way
cifar2_val = [
    (img, label_map[label])
    for img, label in cifar10_val
    if label in label_map
]


In [73]:
# Mini-batches of 64 samples; only the training batches are reshuffled,
# the validation order stays fixed.
train_loader = DataLoader(dataset=cifar2, batch_size=64, shuffle=True)
val_loader = DataLoader(dataset=cifar2_val, batch_size=64, shuffle=False)

In [None]:
class Cifar2CNN(nn.Module):
    """Small convolutional classifier for 32x32 RGB images with two classes.

    The network applies three 3x3 convolutions (padding 1, so the spatial
    size is preserved by each convolution) with tanh activations. The first
    two convolutions are each followed by 2x2 max pooling, which halves the
    spatial size. The result is flattened and passed through two fully
    connected layers that produce two output scores per image.

    Shapes, with ``C = n_chans1`` and batch size ``B``::

        input                  (B, 3,      32, 32)
        conv1 + tanh + pool    (B, C,      16, 16)
        conv2 + tanh + pool    (B, C // 2,  8,  8)
        conv3 + tanh           (B, C // 2,  8,  8)
        flatten                (B, (C // 2) * 8 * 8)
        fcn4 + tanh            (B, 32)
        fcn5                   (B, 2)
    """

    def __init__(
        self,
        n_chans1=32
    ):
        """Build the layers.

        Parameters
        ----------
        n_chans1 : int
            Number of output channels of the first convolutional layer.
            The second and third convolutions use ``n_chans1 // 2`` channels.

        """
        super().__init__()
        self.cifar_size: Final[int] = 32
        self.n_chans1: Final[int] = n_chans1
        self.cov_ker_size: Final[int] = 3
        self.cov_pad: Final[int] = 1
        # Channel count shared by the second and third convolutions
        n_chans2 = n_chans1 // 2
        # First convolutional layer: (B, 3, 32, 32) -> (B, n_chans1, 32, 32)
        self.conv1 = nn.Conv2d(
            3,
            n_chans1,
            self.cov_ker_size,
            padding=self.cov_pad
        )
        # Second convolutional layer, applied after the first pooling:
        # (B, n_chans1, 16, 16) -> (B, n_chans1 // 2, 16, 16)
        self.conv2 = nn.Conv2d(
            n_chans1,
            n_chans2,
            self.cov_ker_size,
            padding=self.cov_pad
        )
        # Third convolutional layer, applied after the second pooling:
        # (B, n_chans1 // 2, 8, 8) -> (B, n_chans1 // 2, 8, 8)
        self.conv3 = nn.Conv2d(
            n_chans2,
            n_chans2,
            self.cov_ker_size,
            padding=self.cov_pad
        )
        # Fully connected layer on the flattened feature map. Two 2x2 poolings
        # shrink each spatial side by a factor of 4, hence `cifar_size // 4`.
        # (B, (n_chans1 // 2) * 8 * 8) -> (B, 32)
        self.fcn4 = nn.Linear(
            ((self.cifar_size // 4) ** 2) * n_chans2,
            32
        )
        # Output layer, one score per class: (B, 32) -> (B, 2)
        self.fcn5 = nn.Linear(32, 2)

    def forward(self, batch):
        """Propagate the batch forward through NN.

        Parameters
        ----------
        batch : torch.Tensor
            Batch of images of shape (B, 3, 32, 32)

        Returns
        -------
        torch.Tensor
            Tensor of shape (B, 2) holding one unnormalized score per class
        """
        out = nnf.max_pool2d(torch.tanh(self.conv1(batch)), 2)
        out = nnf.max_pool2d(torch.tanh(self.conv2(out)), 2)
        # No pooling after the third convolution: fcn4 is sized for 8x8 maps
        out = torch.tanh(self.conv3(out))
        # Flatten each sample's feature map into a single row for fcn4
        out = out.view(-1, self.fcn4.in_features)
        out = torch.tanh(self.fcn4(out))
        out = self.fcn5(out)
        return out

In [76]:
# Scratch check: a Linear layer exposes its output width as `out_features`
l = nn.Linear(in_features=(8 * 8 * 32) // 2, out_features=32)
l.out_features

32

In [77]:
# Scratch check: a Conv2d layer exposes its output channel count as `out_channels`
c = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
c.out_channels

32

In [82]:
# Shape of one preprocessed training image (first element of the (image, label) pair)
cifar10_train[0][0].shape

torch.Size([3, 32, 32])

In [83]:
# A single image has no batch axis, so its 3 * 32 * 32 = 3072 values are split
# into 3 rows of 8 * 8 * 16 = 1024 by view(-1, ...), giving shape (3, 1024).
cifar10_train[0][0].view(-1, 8 * 8 * 16).shape

torch.Size([3, 1024])