In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline

In [16]:
# device config: run on the GPU when CUDA is usable, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [83]:
# hyperparameters
# CIFAR-10 class names; the evaluation cell prints classes[i] for integer label i
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Size of one flattened CIFAR-10 image (3 colour channels x 32 x 32 pixels).
# The previous value (28 * 28) was a leftover from a single-channel 28x28 dataset.
# NOTE(review): input_size and hidden_size are not used by the CNN cells visible here.
input_size = 3 * 32 * 32
hidden_size = 100
num_of_classes = len(classes)   # 10 CIFAR-10 categories, kept in sync with `classes`
num_epochs = 100
batch_size = 32
learning_rate = 0.001

In [84]:
# Preprocessing pipeline applied to every image:
#   1. ToTensor   - convert the image to a float tensor
#   2. Normalize  - per channel, compute (x - 0.5) / 0.5 for each of the 3 RGB channels
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

In [85]:
# CIFAR-10 splits stored under ./data/. Only the training split requests a download;
# the test split is created with download=False and expects the files to be present already.
train_dataset = torchvision.datasets.CIFAR10(
    root="./data/",
    train=True,
    download=True,
    transform=transform,
)
test_dataset = torchvision.datasets.CIFAR10(
    root="./data/",
    train=False,
    download=False,
    transform=transform,
)

# Mini-batch loaders: the training data is shuffled, the test data keeps its order
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Files already downloaded and verified


In [86]:
# model
class CNN(nn.Module):
    """LeNet-style convolutional classifier for 3-channel 32x32 images.

    Architecture: two (5x5 convolution -> ReLU -> 2x2 max-pool) stages followed by
    three fully connected layers. The network returns raw, un-normalised scores
    (logits) for 10 classes; no softmax is applied here.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor. Spatial sizes for a 32x32 input:
        #   conv1: 3x32x32 -> 6x28x28,  pool1: -> 6x14x14
        #   conv2: 6x14x14 -> 16x10x10, pool2: -> 16x5x5
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2))
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.pool2 = nn.MaxPool2d(kernel_size=(2, 2))

        # Classifier head operating on the flattened 16*5*5 feature map
        self.fc1 = nn.Linear(in_features=16*5*5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: tensor of shape (batch, 3, 32, 32).

        Returns:
            Tensor of shape (batch, 10) holding one logit per class.

        Raises:
            RuntimeError: if the spatial size of ``x`` does not produce a
                16x5x5 feature map (raised by ``fc1`` on the shape mismatch).
        """
        # A non-linearity after each convolution; without it the conv stages
        # are separated only by max-pooling.
        out = self.pool1(F.relu(self.conv1(x)))
        out = self.pool2(F.relu(self.conv2(out)))
        # Flatten everything except the batch dimension. Unlike view(-1, 400),
        # this can never silently fold a wrong-sized feature map into the batch axis.
        out = out.flatten(start_dim=1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

# Build the network, then move its parameters to the selected device
model = CNN()
model = model.to(device)

In [87]:
# loss and optimizer
# CrossEntropyLoss takes the raw logits and applies log-softmax internally,
# which is why the model's last layer has no softmax of its own.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [91]:
# training loop
# Put the model in training mode explicitly so this cell behaves the same
# regardless of which cells were executed before it.
model.train()
total_steps = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # move the mini-batch to the same device as the model
        images = images.to(device)
        labels = labels.to(device)

        # forward pass and loss on the raw logits
        outputs = model(images)
        loss = criterion(outputs, labels)

        # backpropagation & weight update (gradients are cleared first because
        # PyTorch accumulates them across backward() calls)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # report the current mini-batch loss every 100 steps
        if (i + 1) % 100 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}: step {i+1}/{total_steps}: loss: {loss.item():.4f}")

Epoch 1/2: step 100/1563: loss: 1.3918
Epoch 1/2: step 200/1563: loss: 1.2851
Epoch 1/2: step 300/1563: loss: 1.6599
Epoch 1/2: step 400/1563: loss: 1.4148
Epoch 1/2: step 500/1563: loss: 1.3275
Epoch 1/2: step 600/1563: loss: 1.2940


KeyboardInterrupt: 

In [89]:
# Evaluate on the test set: overall accuracy plus a per-class breakdown.
# eval() is a no-op for the layers used by CNN (conv / max-pool / linear) but keeps
# this cell correct if dropout or batch-norm layers are added later.
model.eval()
num_classes = len(classes)   # size the per-class counters from the class list, not a literal 10
with torch.no_grad():        # no gradients are needed for inference
    n_correct = 0
    n_samples = 0
    n_class_correct = [0] * num_classes
    n_class_samples = [0] * num_classes
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # max returns (value, index); the index of the largest logit is the predicted class
        _, predicted = torch.max(outputs, 1)
        n_samples += labels.size(0)
        n_correct += (predicted == labels).sum().item()

        # Convert once per batch to plain Python ints so the counters are indexed
        # with ints rather than with one device tensor per sample.
        for label, pred in zip(labels.tolist(), predicted.tolist()):
            if label == pred:
                n_class_correct[label] += 1
            n_class_samples[label] += 1

    acc = 100.0 * n_correct / n_samples
    print(f'Accuracy of the network: {acc} %\n')

    for i in range(num_classes):
        if n_class_samples[i] == 0:
            # a class absent from the evaluated data would otherwise raise ZeroDivisionError
            print(f'Accuracy of {classes[i]}: n/a (no samples)')
            continue
        acc = 100.0 * n_class_correct[i] / n_class_samples[i]
        print(f'Accuracy of {classes[i]}: {acc} %')

Accuracy of the network: 7.33 %

Accuracy of plane: 39.1 %
Accuracy of car: 0.0 %
Accuracy of bird: 34.2 %
Accuracy of cat: 0.0 %
Accuracy of deer: 0.0 %
Accuracy of dog: 0.0 %
Accuracy of frog: 0.0 %
Accuracy of horse: 0.0 %
Accuracy of ship: 0.0 %
Accuracy of truck: 0.0 %
