In [112]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms

Download and build CIFAR10 dataset

In [119]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

dataset = torchvision.datasets.CIFAR10(
    './data/',
    download=True,
    transform=transforms.ToTensor()
)

train_ds, val_ds, test_ds = torch.utils.data.random_split(dataset, [.6, .2, .2])
train_dl, val_dl, test_dl = [
    torch.utils.data.DataLoader(
        ds,
        batch_size=32,
        shuffle=True,
    ) for ds in [train_ds, val_ds, test_ds]]

Construct convnet

In [120]:
class CNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.pool = nn.MaxPool2d(2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(1152, 64)
        self.fc2 = nn.Linear(64, 10)
    
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

Training

In [121]:
def train_epoch(net, dataloader, loss_fn, optimizer):
    net.train()
    size = len(dataloader.dataset)
    for batch, (images, labels) in enumerate(dataloader):
        optimizer.zero_grad()
        outputs = net(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            loss, current = loss.item(), batch * 32 + len(images)
            print(f'loss: {loss:5f} [{current:>5d}/{size:>5d}]')

In [122]:
def validate_epoch(net, dataloader, loss_fn):
    net.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for images, labels in dataloader:
            outputs = net(images)
            test_loss += loss_fn(outputs, labels).item()
            correct += (outputs.argmax(1) == labels).type(torch.float).sum().item()
        
    test_loss /= num_batches
    correct /= size
    print(f'Accuracy: {(100*correct):>0.1f}%, Avg Loss: {test_loss:>5f} \n')

In [123]:
epochs = 3
net = CNet()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)

In [124]:
for epoch in range(epochs):
    print(f'EPOCH {epoch}:')
    train_epoch(net, train_dl, loss_fn, optimizer)
    validate_epoch(net, val_dl, loss_fn)

EPOCH 0:
loss: 2.283231 [   32/30000]
loss: 1.960467 [ 3232/30000]
loss: 2.110461 [ 6432/30000]
loss: 1.966350 [ 9632/30000]
loss: 1.924716 [12832/30000]
loss: 2.421471 [16032/30000]
loss: 1.529417 [19232/30000]
loss: 1.846909 [22432/30000]
loss: 1.876040 [25632/30000]
loss: 1.715833 [28832/30000]
Accuracy: 37.6%, Avg Loss: 1.667424 

EPOCH 1:
loss: 1.790999 [   32/30000]
loss: 1.657287 [ 3232/30000]
loss: 1.621968 [ 6432/30000]
loss: 1.668329 [ 9632/30000]
loss: 1.323678 [12832/30000]
loss: 1.646778 [16032/30000]
loss: 1.877646 [19232/30000]
loss: 1.411642 [22432/30000]
loss: 1.456230 [25632/30000]
loss: 1.689106 [28832/30000]
Accuracy: 41.8%, Avg Loss: 1.570738 

EPOCH 2:
loss: 1.610683 [   32/30000]
loss: 1.841908 [ 3232/30000]
loss: 1.390561 [ 6432/30000]
loss: 1.447668 [ 9632/30000]
loss: 1.743115 [12832/30000]
loss: 1.896026 [16032/30000]
loss: 1.405514 [19232/30000]
loss: 1.260999 [22432/30000]
loss: 1.642641 [25632/30000]
loss: 1.833020 [28832/30000]
Accuracy: 42.3%, Avg Loss: 