<a href="https://colab.research.google.com/github/klane/playground/blob/master/notebooks/pytorch/cnn/cnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

In [None]:
# Per-image preprocessing: convert to a tensor, then normalize with
# mean 0.5 / std 0.5 so values move from the [0, 1] range to [-1, 1].
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST training split, stored under ./data (download requested).
mnist_train = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# MNIST non-training split (train=False), used here for validation.
mnist_val = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

In [None]:
# Mini-batch size shared by the training and validation loaders.
batch_size = 512

# Training loader: shuffling enabled, 4 worker processes, pinned memory.
train_loader = torch.utils.data.DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

# Validation loader: same settings, but without shuffling.
val_loader = torch.utils.data.DataLoader(
    mnist_val,
    batch_size=batch_size,
    num_workers=4,
    pin_memory=True,
)

In [None]:
class MyNet(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers.

    The first linear layer expects 64 * 5 * 5 flattened features, which is
    what the two conv/pool stages produce for single-channel 28x28 inputs.
    The final layer emits 10 raw scores per sample (no softmax applied).
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        # Fully connected classifier head.
        self.fc1 = nn.Linear(64 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Map a batch of images to a (batch, 10) tensor of class scores."""
        # Each stage: convolution -> ReLU -> 2x2 max pooling.
        features = F.max_pool2d(F.relu(self.conv1(x)), 2)
        features = F.max_pool2d(F.relu(self.conv2(features)), 2)

        # Collapse everything except the batch dimension.
        flat = features.reshape(features.size(0), -1)

        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [None]:
def train(net, data_loader, device, optim):
    """Run one optimisation pass over every batch in ``data_loader``.

    Switches ``net`` to training mode, then for each ``(image, label)``
    batch: moves both tensors to ``device``, computes the cross-entropy loss
    of the network output against the labels, back-propagates, and applies
    one step of ``optim``. Returns nothing.
    """
    net.train()

    for inputs, targets in data_loader:
        # Move the batch to the device the network runs on.
        inputs, targets = inputs.to(device), targets.to(device)

        # Reset gradients left over from the previous step.
        optim.zero_grad()

        # Forward pass, then loss and gradients in one go.
        logits = net(inputs)
        F.cross_entropy(logits, targets).backward()

        # Apply the parameter update.
        optim.step()

In [None]:
def evaluate(net, data_loader, device):
    """Return the fraction of correctly predicted labels over ``data_loader``.

    Args:
        net: module mapping a batch of inputs to per-class scores along the
            last dimension; it is switched to eval mode and left there.
        data_loader: iterable yielding ``(image, label)`` tensor pairs.
        device: device each batch is moved to before the forward pass.

    Returns:
        Accuracy as a float: correct predictions divided by the number of
        labels actually evaluated.

    Raises:
        ZeroDivisionError: if ``data_loader`` yields no labels.
    """
    net.eval()
    correct = 0
    total = 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for image, label in data_loader:
            # put data onto the device
            image = image.to(device)
            label = label.to(device)

            # forward through the network, and get the predicted class
            prediction = net(image).argmax(dim=-1)

            # increment correct count
            correct += (prediction == label).sum().item()

            # Count the labels actually seen instead of relying on
            # len(data_loader.dataset): the two differ when the loader drops
            # the last batch or samples a subset, and plain iterables of
            # batches have no ``dataset`` attribute at all.
            total += label.numel()

    return correct / total

In [None]:
# Seed the RNG and pick the GPU when one is available, else the CPU.
torch.manual_seed(1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters.
num_epochs = 10
lr = 0.01

# Model, Adam optimizer, and a schedule that halves the learning rate
# every 2 calls to scheduler.step() (called once per epoch below).
net = MyNet().to(device)
optim = torch.optim.Adam(net.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=2, gamma=0.5)

for epoch in range(num_epochs):
    # Report accuracy as it stands before this epoch's training pass.
    acc_train, acc_val = (
        evaluate(net, train_loader, device),
        evaluate(net, val_loader, device),
    )
    print(
        'Epoch: {}\tTrain Accuracy: {:.4f}%\tValidation Accuracy: {:.4f}%'.format(
            epoch, acc_train * 100, acc_val * 100
        )
    )
    train(net, train_loader, device, optim)
    scheduler.step()

# Final accuracy once the last epoch has finished.
acc_train, acc_val = (
    evaluate(net, train_loader, device),
    evaluate(net, val_loader, device),
)
print(
    'Done! \tTrain Accuracy: {:.4f}%\tValidation Accuracy: {:.4f}%'.format(
        acc_train * 100, acc_val * 100
    )
)