In [28]:
# Make only GPU 0 visible to CUDA in this kernel (only takes effect if set before CUDA is initialised).
# NOTE(review): none of the cells shown move the model or tensors to a GPU, so training runs on the CPU.
%env CUDA_VISIBLE_DEVICES 0

env: CUDA_VISIBLE_DEVICES=0


In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import numpy as np

# Seed every RNG the notebook draws from: torch (weight init) and numpy (batch shuffling).
SEED = 1
torch.manual_seed(SEED)
np.random.seed(SEED)

DOWNLOAD_MNIST = True

# The ToTensor transform only runs when samples are fetched by indexing (train_data[i]);
# the tensors below are built from the raw `.data` / `.targets` attributes and scaled manually.
train_data = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=torchvision.transforms.ToTensor(),
                                        download=DOWNLOAD_MNIST)
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False, download=DOWNLOAD_MNIST)
print(train_data.data.shape)

# Raw 0-255 pixel tensor (N, 28, 28) -> float (N, 1, 28, 28) scaled to [0, 1].
train_x = train_data.data.unsqueeze(1).float() / 255.
train_y = train_data.targets
print(train_x.shape)

# Fixed 2000-image evaluation slice of the test set (these tensors stay on the CPU).
test_x = test_data.data[:2000].unsqueeze(1).float() / 255.
test_y = test_data.targets[:2000]

torch.Size([60000, 28, 28])
torch.Size([60000, 1, 28, 28])


In [30]:
class ConvModule(nn.Module):
    """Bias-free 3x3 convolution followed by BatchNorm and ReLU, with an optional identity shortcut.

    The convolution uses padding=1, so the spatial size (H, W) is preserved.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        order: layer order after the convolution:
            'CBA' -> BatchNorm2d, then ReLU
            'CAB' -> ReLU, then BatchNorm2d (default)
        residual: if True, ``forward`` returns ``module(x) + x``. The input is added unchanged,
            so it must be addable to the output (normally in_channels == out_channels).
            NOTE(review): a 1-channel input would silently broadcast across all output channels.
    """

    def __init__(self, in_channels, out_channels, order='CAB', residual=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.order = order
        self.residual = residual

        assert self.order in ['CBA', 'CAB']
        # The convolution is created first so seeded weight initialisation draws in the same order.
        conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, bias=False, padding=1)
        if order == 'CBA':
            tail = (nn.BatchNorm2d(self.out_channels), nn.ReLU())
        else:
            tail = (nn.ReLU(), nn.BatchNorm2d(self.out_channels))
        self.module = nn.Sequential(conv, *tail)

    def forward(self, x):
        y = self.module(x)
        return y + x if self.residual else y


class ConvLayer(nn.Module):
    """One network stage: 1x1 stem conv, a stack of ConvModules, then optional max-pooling.

    Pipeline:
        stem   : 1x1 ``nn.Conv2d`` mapping in_channels -> out_channels
        blocks : ``num_blocks`` default ``ConvModule(out_channels, out_channels)`` units
        pool   : 2x2 / stride-2 ``nn.MaxPool2d`` (floor-halves H and W) when enabled

    After construction, ``self.downsampling`` holds either the ``nn.MaxPool2d`` module
    (pooling enabled) or the falsy value that was passed in; ``forward`` tests its truthiness.
    """

    def __init__(self, in_channels, out_channels, num_blocks, downsampling=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_blocks = num_blocks
        self.downsampling = downsampling

        # Construction order (stem before blocks) is kept so seeded weight init is unchanged.
        self.stem = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.blocks = nn.Sequential(
            *(ConvModule(out_channels, out_channels) for _ in range(num_blocks))
        )

        if downsampling:
            # Overwrites the boolean flag with the pooling module itself.
            self.downsampling = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        features = self.blocks(self.stem(x))
        if not self.downsampling:
            return features
        return self.downsampling(features)


class Classifier(nn.Module):
    """CNN classifier: ConvLayer stages -> global average pooling -> linear head.

    Args:
        in_channels: channels of the input images.
        base_channels: width of every stage and of the hidden linear layers.
        num_classes: number of output logits.
        dropout_rate: probability that ``self.dropout`` zeroes an activation.
            NOTE: dropout is currently not applied in ``forward``.
        stage_blocks: number of ConvModules per stage; one ConvLayer is built per entry.

    Returns (forward): logits of shape (N, num_classes).
    """

    def __init__(self, in_channels, base_channels, num_classes, dropout_rate=0.5, stage_blocks=(2, 2)):
        super().__init__()

        self.in_channels = in_channels
        self.base_channels = base_channels
        self.num_classes = num_classes
        self.dropout_rate = dropout_rate

        self.layers = nn.ModuleList()
        stage_in = in_channels
        for num_blocks in stage_blocks:
            self.layers.append(ConvLayer(stage_in, base_channels, num_blocks))
            stage_in = base_channels  # every stage after the first consumes base_channels
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        # NOTE(review): no activation between these two Linear layers, so together they compute a
        # single affine map; left unchanged to keep the architecture that produced the logged results.
        self.warm_up_linears = nn.Sequential(
            nn.Linear(base_channels, base_channels * 2),
            nn.Linear(base_channels * 2, base_channels),
        )
        self.out_linear = nn.Linear(base_channels, num_classes)
        # nn.Dropout's `p` is the probability of *dropping* a unit, so dropout_rate is passed as-is
        # (not 1 - rate). Plain Dropout because it would act on flattened (N, C) features, not 4-D maps.
        self.dropout = nn.Dropout(p=self.dropout_rate)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        x = torch.flatten(self.gap(x), 1)  # (N, C, 1, 1) -> (N, C)
        x = self.warm_up_linears(x)
        # Dropout is not applied here; insert `x = self.dropout(x)` to enable it.
        return self.out_linear(x)


In [31]:
# Training configuration.
EPOCH = 5
LR = 0.005
data_size = 20000   # train on the first `data_size` images of train_x (shuffled each epoch)
batch_size = 200    # a trailing partial batch is dropped
eval_every = 50     # evaluate on the test slice every `eval_every` batches

max_accuracy = -1.0

fc = Classifier(in_channels=1, base_channels=64, num_classes=10)

optimizer = torch.optim.Adam(fc.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()

# Never index past the end of the training tensor if data_size is set too large.
num_samples = min(data_size, train_x.size(0))

for epoch in range(EPOCH):
    # torch.randperm uses the generator seeded by torch.manual_seed, so shuffling is reproducible.
    random_indx = torch.randperm(num_samples)
    for batch_i in range(num_samples // batch_size):
        indx = random_indx[batch_i * batch_size:(batch_i + 1) * batch_size]
        b_x = train_x[indx]
        b_y = train_y[indx]

        output = fc(b_x)
        loss = loss_func(output, b_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_i % eval_every == 0:
            fc.eval()  # BatchNorm uses running statistics during evaluation
            with torch.no_grad():  # no autograd graph for the 2000-image evaluation pass
                test_output = fc(test_x)
            fc.train()
            pred_y = test_output.argmax(dim=1)
            accuracy = (pred_y == test_y).float().mean().item()
            max_accuracy = max(max_accuracy, accuracy)

            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item(), '| test accuracy: %.3f' % accuracy)

print(f'max evaluation accuracy: {max_accuracy:.3f}')


Epoch:  0 | train loss: 2.3736 | test accuracy: 0.108
Epoch:  0 | train loss: 0.4529 | test accuracy: 0.678
Epoch:  0 | train loss: 0.1835 | test accuracy: 0.783
Epoch:  0 | train loss: 0.0946 | test accuracy: 0.781
Epoch:  1 | train loss: 0.0901 | test accuracy: 0.941
Epoch:  1 | train loss: 0.1005 | test accuracy: 0.877
Epoch:  1 | train loss: 0.0974 | test accuracy: 0.552
Epoch:  1 | train loss: 0.0600 | test accuracy: 0.878
Epoch:  2 | train loss: 0.0802 | test accuracy: 0.805
Epoch:  2 | train loss: 0.0748 | test accuracy: 0.882
Epoch:  2 | train loss: 0.0160 | test accuracy: 0.925
Epoch:  2 | train loss: 0.0757 | test accuracy: 0.900
Epoch:  3 | train loss: 0.0339 | test accuracy: 0.972
Epoch:  3 | train loss: 0.0311 | test accuracy: 0.970
Epoch:  3 | train loss: 0.0150 | test accuracy: 0.977
Epoch:  3 | train loss: 0.0789 | test accuracy: 0.972
Epoch:  4 | train loss: 0.0317 | test accuracy: 0.938
Epoch:  4 | train loss: 0.0208 | test accuracy: 0.968
Epoch:  4 | train loss: 0.02