In [1]:
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# Tensors and modules created without an explicit device now land on `device`.
torch.set_default_device(device)
# Restrict CPU intra-op parallelism to a single thread.
torch.set_num_threads(1)

In [3]:
# Options shared by every DataLoader built from this dict.
# Shuffling is independent of the hardware, so it is enabled unconditionally;
# previously it was only switched on together with the CUDA-specific options,
# which left CPU runs iterating the data in a fixed order.
dataloader_kwargs = {'batch_size': 32, 'shuffle': True}
if use_cuda:
    # Worker processes and pinned host memory only pay off when batches are
    # copied to a GPU.
    dataloader_kwargs.update({'num_workers': 4, 'pin_memory': True})

In [4]:
# Convert each image to a float tensor and normalise every RGB channel with
# mean 0.5 / std 0.5 (stats written out per channel instead of relying on
# broadcasting of a single value).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10('./data', train=False, transform=transform)

# Shuffle is chosen explicitly per loader rather than inherited from the shared
# kwargs: the training set is always reshuffled each epoch, while the test set
# keeps a fixed order. The generator is created on `device` because the default
# device was changed globally and the loader draws its random state from it.
train_loader = DataLoader(
    train_dataset,
    generator=torch.Generator(device),
    **{**dataloader_kwargs, 'shuffle': True},
)
test_loader = DataLoader(
    test_dataset,
    generator=torch.Generator(device),
    **{**dataloader_kwargs, 'shuffle': False},
)

Files already downloaded and verified


In [5]:
def conv_block(
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        padding: int = 1) -> nn.Sequential:
    """Build a two-convolution feature block.

    The block is ``(Conv2d -> BatchNorm2d -> ReLU) x 2`` followed by a 2x2
    max-pool and dropout with p=0.2.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by both convolutions.
        kernel_size: Kernel size used by both convolutions.
        padding: Zero padding used by both convolutions.

    Returns:
        An ``nn.Sequential`` holding the eight layers in forward order.
    """
    layers = []
    channels = in_channels
    # Two identical conv stages; only the first one changes the channel count.
    for _ in range(2):
        layers.append(nn.Conv2d(channels, out_channels, kernel_size=kernel_size, padding=padding))
        layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU())
        channels = out_channels

    # Halve the spatial resolution, then regularise.
    layers.append(nn.MaxPool2d(2))
    layers.append(nn.Dropout(0.2))
    return nn.Sequential(*layers)

In [6]:
class CIFAR10Model(nn.Module):
    """Small CNN classifier: three conv blocks followed by a two-layer MLP head.

    The forward pass returns log-probabilities over 10 classes (the final layer
    is ``LogSoftmax`` along dim 1).
    """

    def __init__(self):
        super().__init__()

        # Feature extractor: 3 -> 64 -> 128 -> 256 channels. Each block ends in a
        # 2x2 max-pool, so the spatial size is halved three times.
        feature_layers = [
            conv_block(3, 64),
            conv_block(64, 128),
            conv_block(128, 256),
        ]
        # Classifier head. The 256 * 4 * 4 input width assumes a 4x4 feature map
        # after the three pools, i.e. 32x32 input images.
        head_layers = [
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(256, 10),
            nn.LogSoftmax(dim=1),
        ]

        # Kept as one flat Sequential named `seq` so layer indices (and therefore
        # state_dict keys) are the same as listing the layers inline.
        self.seq = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the full layer stack to a batch of images."""
        return self.seq(x)


model = CIFAR10Model().to(device)

In [7]:
# Total number of scalar parameters in the model (displayed as the cell output).
sum(param.numel() for param in model.parameters())

2199114

In [8]:
# AdamW optimiser over all model parameters, with the learning rate multiplied
# by 0.75 after every second scheduler step.
optimizer = optim.AdamW(params=model.parameters(), lr=3e-4)
scheduler = StepLR(optimizer=optimizer, step_size=2, gamma=0.75)

# Classification loss. The spelling `critrion` is kept on purpose: train() and
# test() look this global up under exactly that name.
critrion = nn.CrossEntropyLoss()

In [9]:
def train(model: nn.Module, device: torch.device, train_loader: DataLoader, optimizer: optim.Optimizer, epoch: int):
    """Run one training pass over ``train_loader``.

    For every batch: move it to ``device``, zero the gradients, run the forward
    pass, compute the loss with the module-level ``critrion``, back-propagate
    and take an optimiser step. Every 100th batch a progress line with the
    current loss and the wall-clock time of that step is printed.

    Args:
        model: Network to optimise; switched to training mode.
        device: Device the batches are moved to.
        train_loader: Loader yielding ``(data, target)`` batches.
        optimizer: Optimiser that updates ``model``'s parameters.
        epoch: Epoch number, used only in the progress line.
    """
    model.train()

    # CUDA work is queued asynchronously, so a host-side timer only measures the
    # launch overhead unless the device is synchronised around the timed region.
    on_cuda = torch.device(device).type == "cuda"

    for batch_idx, (data, target) in enumerate(train_loader):
        data: torch.Tensor = data.to(device)
        target: torch.Tensor = target.to(device)

        # Synchronise only on the steps that are actually reported, so the
        # remaining steps keep running without extra stalls.
        should_log = batch_idx % 100 == 0
        if should_log and on_cuda:
            torch.cuda.synchronize(device)

        start = time.perf_counter_ns()
        optimizer.zero_grad()
        output: torch.Tensor = model(data)
        loss: torch.Tensor = critrion(output, target)
        loss.backward()
        optimizer.step()

        if should_log:
            if on_cuda:
                # Wait for the queued forward/backward/step work to finish.
                torch.cuda.synchronize(device)
            end = (time.perf_counter_ns() - start) / 1000.0 / 1000.0  # ns -> ms
            print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}\tTime: {end:.3f}ms")

In [10]:
def test(model: nn.Module, device: torch.device, test_loader: DataLoader):
    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data: torch.Tensor = data.to(device)
            target: torch.Tensor = target.to(device)

            output: torch.Tensor = model(data)
            test_loss += critrion(output, target).item()
            pred: torch.Tensor = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n")

In [13]:
# Each epoch: one training pass, one evaluation pass, then advance the
# learning-rate schedule. Epoch numbers are 1-based for the progress output.
epochs = 1
for epoch in range(1, 1 + epochs):
    train(model=model, device=device, train_loader=train_loader, optimizer=optimizer, epoch=epoch)
    test(model=model, device=device, test_loader=test_loader)
    scheduler.step()


Test set: Average loss: 0.0122, Accuracy: 8843/10000 (88%)


Test set: Average loss: 0.0122, Accuracy: 8846/10000 (88%)


Test set: Average loss: 0.0120, Accuracy: 8870/10000 (89%)


Test set: Average loss: 0.0120, Accuracy: 8885/10000 (89%)


Test set: Average loss: 0.0121, Accuracy: 8892/10000 (89%)


Test set: Average loss: 0.0120, Accuracy: 8894/10000 (89%)


Test set: Average loss: 0.0126, Accuracy: 8851/10000 (89%)


Test set: Average loss: 0.0120, Accuracy: 8895/10000 (89%)


Test set: Average loss: 0.0122, Accuracy: 8887/10000 (89%)


Test set: Average loss: 0.0122, Accuracy: 8898/10000 (89%)


Test set: Average loss: 0.0123, Accuracy: 8876/10000 (89%)



KeyboardInterrupt: 

In [12]:
# Save the model
# torch.save(model.state_dict(), "cifar10_20epochs.pt")