In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Training hyperparameters.
num_epochs = 20
batch_size = 128
learning_rate = 0.01

# Seed torch's RNGs (CPU and CUDA) so weight initialization and the training-set
# shuffle order are reproducible under Restart Kernel -> Run All.
# NOTE: bit-exact GPU reproducibility additionally needs deterministic cuDNN settings.
SEED = 42
torch.manual_seed(SEED)

# Per-channel (R, G, B) mean / std used to normalize CIFAR-10 images.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.247, 0.243, 0.261)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
print(train_dataset.classes)

# Page-locked host memory speeds up host->GPU copies; it only helps when training on CUDA.
pin_memory = device.type == 'cuda'
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=4, pin_memory=pin_memory)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False,
                         num_workers=4, pin_memory=pin_memory)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100%|███████████████████████████████████████████████████████████████████████████████| 170M/170M [02:31<00:00, 1.13MB/s]


Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


In [3]:
class SimpleCNN(nn.Module):
    """Small two-stage convolutional classifier for 32x32 images (e.g. CIFAR-10).

    Architecture:
        [conv3x3 -> ReLU -> maxpool 2x2] x 2 -> flatten -> fc(128) -> ReLU -> fc(num_classes)

    Args:
        num_classes: Number of output logits. Defaults to 10 (the CIFAR-10 classes).
        in_channels: Number of input image channels. Defaults to 3 (RGB).

    The first fully connected layer is sized for 32x32 inputs. The 3x3 convolutions
    use padding=1 and preserve spatial size, and each 2x2 / stride-2 pool halves it
    (32 -> 16 -> 8). The flattened feature vector therefore has 64 * 8 * 8 entries.
    Submodule names (conv1, pool1, conv2, pool2, fc1, fc2) are kept stable so
    previously saved state_dicts still load.
    """

    def __init__(self, num_classes: int = 10, in_channels: int = 3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 64 channels x 8 x 8 spatial positions after two halvings of a 32x32 input.
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (batch, in_channels, 32, 32) tensor to (batch, num_classes) raw logits.

        No softmax is applied; the logits are meant for a loss such as
        nn.CrossEntropyLoss or for argmax-based prediction.
        """
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [4]:
def train_one_epoch(model, loader, criterion, optimizer, device, epoch, num_epochs, log_interval=100):
    """Run one optimization pass over ``loader`` and return the mean training loss.

    Every ``log_interval`` steps, prints the average loss of the most recent
    ``log_interval`` batches.

    Args:
        model: The nn.Module to train; it is switched to train mode.
        loader: Iterable of ``(images, labels)`` batches that supports ``len()``.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to before the forward pass.
        epoch: Zero-based epoch index (printed as ``epoch + 1``).
        num_epochs: Total number of epochs, used only for the progress message.
        log_interval: Number of steps between progress prints.

    Returns:
        Mean loss over all batches of this epoch, or 0.0 if ``loader`` is empty.
    """
    model.train()
    running_loss = 0.0  # reset after every progress print
    epoch_loss = 0.0
    num_batches = 0
    for step, (images, labels) in enumerate(loader, start=1):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        num_batches += 1
        if step % log_interval == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{step}/{len(loader)}], Loss: {running_loss / log_interval:.4f}')
            running_loss = 0.0
    return epoch_loss / num_batches if num_batches else 0.0


def evaluate(model, loader, device):
    """Count top-1 correct predictions of ``model`` over ``loader``.

    Args:
        model: The nn.Module to evaluate; it is switched to eval mode.
        loader: Iterable of ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(correct, total)`` as Python ints.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct, total


if __name__ == "__main__":
    model = SimpleCNN().to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

    for epoch in range(num_epochs):
        train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, num_epochs)

    correct, total = evaluate(model, test_loader, device)
    # Guard against an empty test loader instead of raising ZeroDivisionError.
    accuracy = 100 * correct / total if total else 0.0
    print(f'Test Accuracy of the model on the {total} test images: {accuracy:.2f}%')

    torch.save(model.state_dict(), 'simple_cnn_cifar10.pth')

Epoch [1/20], Step [100/391], Loss: 2.0011
Epoch [1/20], Step [200/391], Loss: 1.5408
Epoch [1/20], Step [300/391], Loss: 1.3767
Epoch [2/20], Step [100/391], Loss: 1.1815
Epoch [2/20], Step [200/391], Loss: 1.1237
Epoch [2/20], Step [300/391], Loss: 1.0921
Epoch [3/20], Step [100/391], Loss: 0.9358
Epoch [3/20], Step [200/391], Loss: 0.9446
Epoch [3/20], Step [300/391], Loss: 0.9181
Epoch [4/20], Step [100/391], Loss: 0.8037
Epoch [4/20], Step [200/391], Loss: 0.8069
Epoch [4/20], Step [300/391], Loss: 0.8112
Epoch [5/20], Step [100/391], Loss: 0.7053
Epoch [5/20], Step [200/391], Loss: 0.6876
Epoch [5/20], Step [300/391], Loss: 0.6976
Epoch [6/20], Step [100/391], Loss: 0.5643
Epoch [6/20], Step [200/391], Loss: 0.5989
Epoch [6/20], Step [300/391], Loss: 0.5968
Epoch [7/20], Step [100/391], Loss: 0.4900
Epoch [7/20], Step [200/391], Loss: 0.4903
Epoch [7/20], Step [300/391], Loss: 0.5118
Epoch [8/20], Step [100/391], Loss: 0.3825
Epoch [8/20], Step [200/391], Loss: 0.3993
Epoch [8/20