<a href="https://colab.research.google.com/github/camillan/computer_vision/blob/main/cnn_cifar_image_recognition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [2]:
import torch
import torchvision
import torchvision.transforms as transforms

# convert images to tensors and normalize pixel values
transform = transforms.Compose(
    [transforms.ToTensor(), # converts to tensor
     transforms.Normalize((0.5,), (0.5,))]) # normalizes the RGB values

# load cifar-10
# train
trainset = torchvision.datasets.CIFAR10(root='./data',
                                        train=True,
                                        download=True,
                                        transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=False)

# test
testset = torchvision.datasets.CIFAR10(root='./data',
                                       train=False,
                                       download=True,
                                       transform=transform)

testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

# classes that are potential labels in CIFAR-10
classes = trainset.classes

# Define CNN

In [4]:
import torch.nn as nn
import torch.nn.functional as F

# define simple CNN with 2 convolutional layers and 2 fully connected layers
class SimpleCNN(nn.Module):
  def __init__(self):
    super(SimpleCNN, self).__init__()

    # first conv layer 3 input channels (RGB), output = 32 filters, kernel = 3x3
    self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)

    # second conv layer
    self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)

    # max pooking layer to downsample feature maps by 2x
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    # fully connected layer - flatten from 64x8x8 to 128
    self.fc1 = nn.Linear(in_features=64*8*8, out_features=128)

    # output layer has 10 classes for CIFAR-10
    self.fc2 = nn.Linear(128, 10)

  def forward(self, x):
    # apply first conv -> ReLU -> pooling -> output shape
    x = self.pool(F.relu(self.conv1(x)))

    # apply second conv -> ReLU -> pooling -> output shape
    x = self.pool(F.relu(self.conv2(x)))

    # flatten 3d feature maps to 1d vector per image
    x = x.view(-1, 64*8*8)

    # fully connected layer -> ReLU
    x = F.relu(self.fc1(x))

    # final layer -> class scores as logits
    x = self.fc2(x)

    return x

# instatiate the model
model = SimpleCNN()


# Setup training loop

In [5]:
import torch.optim as optim

# define the loss function (cross entropy for multi class classification)
criterion = nn.CrossEntropyLoss()

# define the optimizer (adam optimizer)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# training loop for 5 epochs
for epoch in range(5):
  # track loss for this epoch
  running_loss = 0.0

  for images, labels in trainloader:
    # make zero the gradients from previous step
    optimizer.zero_grad()

    # forward pass: compute predicted outputs
    outputs = model(images)

    # compute loss between predicted and true labels
    loss = criterion(outputs, labels)

    # backward pass: compute gradients
    loss.backward()

    # update weights
    optimizer.step()

    # accumulate loss
    running_loss += loss.item()

  # print average loss per epoch
  print(f"Epoch {epoch+1}, Loss: {running_loss/len(trainloader):.4f}")

print('Finished Training')

Epoch 1, Loss: 1.3529
Epoch 2, Loss: 0.9686
Epoch 3, Loss: 0.8053
Epoch 4, Loss: 0.6837
Epoch 5, Loss: 0.5796
Finished Training


# Evaluate on test set

In [6]:
correct = 0
total = 0

# don't compute gradients during evaluation (only happens during training)
with torch.no_grad():
  for images, labels in testloader:
    outputs = model(images)

    # get predicted class with highest score
    _, predicted = torch.max(outputs, dim=1)

    # count correct predictions
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

# print overall accuracy
print(f"Test accuracy: {100 * correct/total:.2f}%")

Test accuracy: 71.18%
