In [15]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import time

from simple_conv_net_train import SimpleConvNet, NNType
from simple_conv_net_func import *

# Setup

In [2]:
seed = 1
no_cuda = False
batch_size = 64
lr = 0.01
momentum = 0.5

In [3]:
torch.manual_seed(seed)
use_cuda = not no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

In [4]:
def get_test_loader():
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=False, **kwargs)
    return test_loader

def get_1_batch(device, test_loader):
    data, target = test_loader.__iter__().__next__()
    data, target = data.to(device), target.to(device)
    return (data, target)

def test_1_batch(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            break
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

# Validate calculations and test speed

In [5]:
test_loader = get_test_loader()

In [6]:
model = SimpleConvNet(device, NNType.TORCH)

In [7]:
(data, target) = get_1_batch(device, test_loader)

In [8]:
model.eval()

SimpleConvNet(
  (conv_layer): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (fc_layer1): Linear(in_features=2880, out_features=500, bias=True)
  (fc_layer2): Linear(in_features=500, out_features=10, bias=True)
)

## Conv

In [9]:
start = time.time()
z_conv = model.conv_layer(data)
end = time.time()
print(end - start)

0.003996849060058594


In [10]:
start = time.time()
z_conv_vector = conv2d_vector(data, model.conv_layer.weight, model.conv_layer.bias, device)
end = time.time()
print(end - start)

1.1969952583312988


In [11]:
start = time.time()
with torch.no_grad():
    z_conv_scalar = conv2d_scalar(data, model.conv_layer.weight, model.conv_layer.bias, device)
end = time.time()
print(end - start)

1348.8700397014618


In [12]:
diff_mse(z_conv, z_conv_vector), diff_mse(z_conv, z_conv_scalar)

(0.0, 3.5518226001399894e-15)

## Pool

In [13]:
start = time.time()
z_pool = F.max_pool2d(z_conv, 2, 2)
end = time.time()
print(end - start)

0.0020008087158203125


In [18]:
start = time.time()
z_pool_vector = pool2d_vector(z_conv_scalar, device)
end = time.time()
print(end - start)

26.607872009277344


In [19]:
start = time.time()
z_pool_scalar = pool2d_scalar(z_conv_scalar, device)
end = time.time()
print(end - start)

18.50862216949463


In [20]:
diff_mse(z_pool, z_pool_vector), diff_mse(z_pool, z_pool_scalar)

(3.790910778513198e-15, 3.790910778513198e-15)

## Reshape

In [25]:
start = time.time()
z_pool_reshaped = z_pool.view(-1, 20*12*12)
end = time.time()
print(end - start)

0.001001119613647461


In [26]:
start = time.time()
z_pool_reshaped_vector = reshape_vector(z_pool_vector, device)
end = time.time()
print(end - start)

0.0010042190551757812


In [27]:
start = time.time()
z_pool_reshaped_scalar = reshape_scalar(z_pool_scalar, device)
end = time.time()
print(end - start)

4.903999328613281


In [28]:
diff_mse(z_pool_reshaped, z_pool_reshaped_vector), diff_mse(z_pool_reshaped, z_pool_reshaped_scalar)

(3.790910778513198e-15, 3.790910778513198e-15)

## FC1

In [32]:
start = time.time()
z_fc1 = model.fc_layer1(z_pool_reshaped)
end = time.time()
print(end - start)

0.0029981136322021484


In [33]:
start = time.time()
z_fc1_vector = fc_layer_vector(z_pool_reshaped_vector, model.fc_layer1.weight, model.fc_layer1.bias, model.device)
end = time.time()
print(end - start)

0.0030028820037841797


In [None]:
start = time.time()
with torch.no_grad():
    z_fc1_scalar = fc_layer_scalar(z_pool_reshaped_scalar, model.fc_layer1.weight, model.fc_layer1.bias, model.device)
end = time.time()
print(end - start)

In [None]:
diff_mse(z_fc1, z_fc1_vector), diff_mse(z_fc1, z_fc1_scalar)

## Relu

In [None]:
start = time.time()
z_relu = F.relu(z_fc1)
end = time.time()
print(end - start)

In [None]:
start = time.time()
z_relu_vector = relu_vector(z_fc1_vector, model.device)
end = time.time()
print(end - start)

In [None]:
start = time.time()
z_relu_scalar = relu_scalar(z_fc1_scalar, model.device)
end = time.time()
print(end - start)

In [None]:
diff_mse(z_relu, z_relu_vector), diff_mse(z_relu, z_relu_scalar)

## FC2

In [None]:
start = time.time()
z_fc2 = model.fc_layer2(z_relu)
end = time.time()
print(end - start)

In [None]:
start = time.time()
z_fc2_vector = fc_layer_vector(z_relu_vector, model.fc_layer2.weight, model.fc_layer2.bias, model.device)
end = time.time()
print(end - start)

In [None]:
start = time.time()
with torch.no_grad():
    z_fc2_scalar = fc_layer_scalar(z_relu_scalar, model.fc_layer2.weight, model.fc_layer2.bias, model.device)
end = time.time()
print(end - start)

In [None]:
diff_mse(z_fc2, z_fc2_vector), diff_mse(z_fc2, z_fc2_scalar)