# Practice CNN with GPU support

- move the model, loss function and tensors to the GPU with `.to(device)` (or the `cuda()` method)
- bring tensors back from the GPU with the `cpu()` method when host-side data is needed, e.g. the loss value

**In this MNIST example, training the default CNN takes ~300 s on the CPU versus ~50 s on the GPU.**

In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable

In [11]:
# 0 switch on/off gpu: use CUDA whenever a device is available, otherwise fall back to the CPU
use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu")

In [12]:
# 1 create model
class CNN(nn.Module):
    """Two-stage convolutional classifier for single-channel 28x28 images.

    Each stage is a 5x5 convolution with padding 2 (spatial size preserved),
    a ReLU, and a 2x2 max-pool (spatial size halved). A final linear layer maps
    the flattened 32x7x7 feature map to 10 class scores (raw logits).
    """

    def __init__(self):
        super().__init__()
        # stage 1: (N, 1, 28, 28) -conv-> (N, 16, 28, 28) -pool-> (N, 16, 14, 14)
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # stage 2: (N, 16, 14, 14) -conv-> (N, 32, 14, 14) -pool-> (N, 32, 7, 7)
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # classifier head: 32 * 7 * 7 = 1568 flattened features -> 10 logits
        self.out = nn.Linear(in_features=32 * 7 * 7, out_features=10)

    def forward(self, x):
        features = self.conv2(self.conv1(x))
        batch = features.size(0)
        flat = features.view(batch, -1)  # (N, 32 * 7 * 7)
        return self.out(flat)

# build the model and move its parameters onto the selected device
# (Module.to returns the module itself, so the chained call keeps the same object)
net = CNN().to(device)

print(net)  # net architecture

CNN(
  (conv1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (out): Linear(in_features=1568, out_features=10, bias=True)
)


In [13]:
# 2 load data
import torchvision
import torch.utils.data as Data

# training split; ToTensor() turns each image into a float tensor scaled to [0, 1],
# matching the manual test-set preprocessing below
train_data = torchvision.datasets.MNIST(
    root='./dataset/',
    train=True,
    transform=torchvision.transforms.ToTensor())

# test split: only its raw .data / .targets tensors are used, so no transform is set
test_data = torchvision.datasets.MNIST(
    root='./dataset/',
    train=False)

# batch train data
train_loader = Data.DataLoader(dataset=train_data, batch_size=32, shuffle=True)

# preprocess test data once, by hand (the dataset transform is bypassed here):
# size: (n, 28, 28) -> (n, 1, 28, 28)
# value [0, 255] -> [0, 1]
# (.data is a plain tensor with no grad history, so neither Variable nor no_grad is needed)
test_x = torch.unsqueeze(test_data.data, dim=1).float() / 255.0
test_y = test_data.targets

# keep the whole test set resident on the device for repeated evaluation
test_x = test_x.to(device)
test_y = test_y.to(device)

In [15]:
# 3 train and evaluate model
# cross entropy loss: takes the model's raw logits and integer class labels
fun_loss = nn.CrossEntropyLoss().to(device)
# Adam over all model parameters; to compare, swap in
# torch.optim.SGD(net.parameters(), lr=0.02)
optimizer = torch.optim.Adam(net.parameters(), lr=0.02)

def evaluate(x, y, model=None):
    '''Classify ``x`` and score the predictions against ``y``.

    x: (n, 1, 28, 28) input images
    y: (n,) integer class labels (same device as x)
    model: network to evaluate; defaults to the global ``net``

    Returns (accuracy, y_): accuracy is a Python float in [0, 1] and
    y_ is the (n,) tensor of predicted class indices.
    '''
    if model is None:
        model = net
    was_training = model.training
    model.eval()  # inference behaviour for layers such as dropout / batch-norm
    try:
        # no_grad: evaluation never backpropagates, so don't build an
        # autograd graph over the whole evaluation set
        with torch.no_grad():
            out = model(x)
    finally:
        model.train(was_training)  # restore the caller's train/eval mode
    y_ = out.argmax(dim=1)  # index of the highest score per sample
    # reduce on-device; the builtin sum() would iterate the tensor element by
    # element (one host sync per element on a GPU)
    accuracy = (y_ == y).float().mean()
    return accuracy.item(), y_


# training and testing
for epoch in range(3):
    for step, (x, y) in enumerate(train_loader):
        # move the batch onto the same device as the model
        b_x = x.to(device)
        b_y = y.to(device)

        output = net(b_x)               # forward pass: (batch, 10) logits
        loss = fun_loss(output, b_y)    # cross entropy loss
        optimizer.zero_grad()           # clear gradients for this training step
        loss.backward()                 # backpropagation, compute gradients
        optimizer.step()                # apply gradients

        if step % 50 == 0:
            # periodically score the full test set
            accuracy, _ = evaluate(test_x, test_y)
            # .item() copies the scalar to the host itself, so no explicit .cpu() is needed
            print(f'epoch: {epoch} | loss: {loss.item()} | accuracy: {accuracy}')

epoch: 0 | loss: 0.009286634624004364 | accuracy: 0.1623999923467636
epoch: 0 | loss: 0.06347671896219254 | accuracy: 0.9387999773025513
epoch: 0 | loss: 0.39487653970718384 | accuracy: 0.9411999583244324
epoch: 0 | loss: 0.1898575723171234 | accuracy: 0.9490999579429626
epoch: 0 | loss: 0.14610183238983154 | accuracy: 0.9386999607086182
epoch: 0 | loss: 0.10090676695108414 | accuracy: 0.9472999572753906
epoch: 0 | loss: 0.3239684998989105 | accuracy: 0.9610999822616577
epoch: 0 | loss: 0.3144674003124237 | accuracy: 0.9562999606132507
epoch: 0 | loss: 0.027942165732383728 | accuracy: 0.9619999527931213
epoch: 0 | loss: 0.09400998800992966 | accuracy: 0.9692999720573425
epoch: 0 | loss: 0.00919688493013382 | accuracy: 0.9550999999046326
epoch: 0 | loss: 0.11639679968357086 | accuracy: 0.9537999629974365
epoch: 0 | loss: 0.051452457904815674 | accuracy: 0.957099974155426
epoch: 0 | loss: 0.013265540823340416 | accuracy: 0.9651999473571777
epoch: 0 | loss: 0.05242225155234337 | accuracy: