Is Arraymancer slow or are the examples outdated? #265

Open
andreaferretti opened this issue Aug 8, 2018 · 7 comments


@andreaferretti
Contributor

After a long time, I decided to finally learn Arraymancer, and I was really pleased to see that the library has grown much more than I expected! Kudos to all contributors!

I started trying out the neural network example in the README, which features a simple convolutional network to learn MNIST. I ran the example exactly as it is, using -d:release, on the CPU.

It seemed to be really slow. While it was running, I modified a similar example I had in PyTorch and started it. The PyTorch example looks like this (it may not be identical, since I just made a quick edit to use the same parameters, but it should be more or less the same):

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as transforms
from torchvision.datasets.mnist import MNIST
from torch.utils.data import DataLoader
from datasets import TransformedDataset, Slice


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 50, 5)
        self.linear1 = nn.Linear(800, 500)
        self.linear2 = nn.Linear(500, 10)

    def forward(self, x):
        batch_size, _, _, _ = x.size()
        x = self.conv1(x)
        x = F.relu(F.max_pool2d(x, 2))
        x = self.conv2(x)
        x = F.relu(F.max_pool2d(x, 2))
        x = x.view(batch_size, 800)
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        return x

def run_epoch(model, criterion, optimizer, dataloaders, set, train=True):
    model.train(train)

    running_loss = 0.0
    running_corrects = 0
    count = len(dataloaders[set].dataset)

    for data in dataloaders[set]:
        inputs, labels = data
        inputs, labels = Variable(inputs), Variable(labels)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass
        if train:
            loss.backward()
            optimizer.step()

        # Statistics
        _, preds = torch.max(outputs.data, 1)
        running_loss += loss.data[0]
        running_corrects += torch.sum(preds == labels.data)

    epoch_loss = running_loss / count
    epoch_acc = 100 * running_corrects / count

    print('{} loss: {:.4f} accuracy: {:.2f}%'.format(set, epoch_loss, epoch_acc))


def train(model, criterion, optimizer, dataloaders):
    num_epochs = 5

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)

        run_epoch(model, criterion, optimizer, dataloaders, 'train', train=True)
        run_epoch(model, criterion, optimizer, dataloaders, 'validation', train=False)

def main():
    torch.manual_seed(1)
    batch_size = 32
    model = Net()

    print(model)

    criterion = nn.CrossEntropyLoss()  # LogSoftmax + ClassNLL Loss
    optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

    mnist = MNIST("data/mnist", train=True, download=True)
    mnist_test = MNIST("data/mnist", train=False, download=True)
    transform = transforms.ToTensor()
    dataset = TransformedDataset(mnist, transform)
    train_size = len(mnist) * 5 // 6
    trainset = Slice(dataset, 0, train_size)
    valset = Slice(dataset, train_size, len(mnist))
    testset = TransformedDataset(mnist_test, transform)
    dataloaders = {
        'train': DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4),
        'validation': DataLoader(valset, batch_size=batch_size, shuffle=True, num_workers=4),
        'test': DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=4),
    }

    train(model, criterion, optimizer, dataloaders)
    run_epoch(model, criterion, optimizer, dataloaders, 'test', train=False)

if __name__ == '__main__':
    main()

This ran noticeably faster - namely 4m14s for 5 epochs, while Arraymancer took 10m.

It seems I am doing something wrong with Arraymancer - or the example is possibly outdated - but I am not really sure :-?

@mratsim
Owner

mratsim commented Aug 10, 2018

Interesting, it's been a while since I've worked on the internals and benchmarked Arraymancer. I will come back to it in 2~3 weeks. Are you using devel? There was a regression in stable:

#221 and nim-lang/Nim#7743

@andreaferretti
Contributor Author

I updated Arraymancer to @#head and noticed a big improvement over v0.4. Now the timing is 6m17s, much better!

@mratsim
Owner

mratsim commented Aug 10, 2018

Still surprising.

Is it on Nim devel as well?

Also, are you using -d:openmp? With OpenMP I expect Arraymancer to be faster than PyTorch.

@andreaferretti
Contributor Author

I tried to use -d:openmp (I had to change the clang path to use brew's clang, since the stock clang on macOS does not support OpenMP), and now I am down to 5m8s, even better.
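For reference, a compile invocation along these lines should reproduce that setup (a sketch: the Homebrew LLVM paths and the source file name are assumptions, adjust them to your install):

nim c -d:release -d:openmp --cc:clang \
  --clang.exe:/usr/local/opt/llvm/bin/clang \
  --clang.linkerexe:/usr/local/opt/llvm/bin/clang \
  example_mnist.nim  # hypothetical name for the README example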

I have changed the PyTorch code a little - the following should be self-contained if you want to try and compare. It runs in 3m2s on my MacBook, which is less than I measured the first time - only now do I realize I was running the two together back then, which altered the (already not very scientific) benchmark.

It should do more or less what the Arraymancer example does (save for some prints) - but please check, I may be missing something:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as transforms
from torchvision.datasets.mnist import MNIST
from torch.utils.data import DataLoader, Dataset


class TransformedDataset(Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, label = self.dataset[idx]

        if self.transform:
            img = self.transform(img)

        return img, label

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 50, 5)
        self.linear1 = nn.Linear(800, 500)
        self.linear2 = nn.Linear(500, 10)

    def forward(self, x):
        batch_size, _, _, _ = x.size()
        x = self.conv1(x)
        x = F.relu(F.max_pool2d(x, 2))
        x = self.conv2(x)
        x = F.relu(F.max_pool2d(x, 2))
        x = x.view(batch_size, 800)
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        return x

def run_epoch(model, criterion, optimizer, dataloaders, set, train=True):
    model.train(train)

    running_loss = 0.0
    running_corrects = 0
    count = len(dataloaders[set].dataset)

    for data in dataloaders[set]:
        inputs, labels = data
        inputs, labels = Variable(inputs), Variable(labels)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass
        if train:
            loss.backward()
            optimizer.step()

        # Statistics
        _, preds = torch.max(outputs.data, 1)
        running_loss += loss.data[0]
        running_corrects += torch.sum(preds == labels.data)

    epoch_loss = running_loss / count
    epoch_acc = 100 * running_corrects / count

    print('{} loss: {:.4f} accuracy: {:.2f}%'.format(set, epoch_loss, epoch_acc))


def train(model, criterion, optimizer, dataloaders):
    num_epochs = 5

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)

        run_epoch(model, criterion, optimizer, dataloaders, 'train', train=True)
        run_epoch(model, criterion, optimizer, dataloaders, 'validation', train=False)

def main():
    torch.manual_seed(1)
    batch_size = 32
    model = Net()

    print(model)

    criterion = nn.CrossEntropyLoss()  # LogSoftmax + ClassNLL Loss
    optimizer = optim.SGD(model.parameters(), lr=1e-2)

    mnist = MNIST("data/mnist", train=True, download=True)
    mnist_test = MNIST("data/mnist", train=False, download=True)
    transform = transforms.ToTensor()
    trainset = TransformedDataset(mnist, transform)
    testset = TransformedDataset(mnist_test, transform)
    dataloaders = {
        'train': DataLoader(trainset, batch_size=batch_size, shuffle=False, num_workers=0),
        'validation': DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=0)
    }

    train(model, criterion, optimizer, dataloaders)

if __name__ == '__main__':
    main()

@mratsim
Owner

mratsim commented Oct 25, 2018

I didn't investigate this specific example yet, but while working on the future Arraymancer backend, Laser, I noticed a huge perf bottleneck in Nim's max. Because it is not inline, the compiler is unable to vectorize it with SSE or AVX instructions, which impacts all relu calls and max-pooling layers. You can see an in-depth explanation here, which shows a 7x slowdown compared to an SSE3 implementation that saturates the memory bandwidth.
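To get a feel for how much element-wise vectorization matters for this pattern, here is a small Python/numpy analogy (not the Nim code in question - the Nim issue is that a non-inline max in the generated C blocks the same kind of SIMD optimization):

import time
import numpy as np

x = np.random.randn(10_000_000).astype(np.float32)

# Vectorized relu: max applied element-wise in a tight SIMD-friendly loop.
t0 = time.perf_counter()
y = np.maximum(x, 0.0)
dt_vec = time.perf_counter() - t0

# Per-element relu: every max goes through an opaque call, no vectorization.
sub = x[:500_000]
t0 = time.perf_counter()
z = np.fromiter((max(v, 0.0) for v in sub), dtype=np.float32)
dt_call = time.perf_counter() - t0

print(f"vectorized:  {x.size / dt_vec / 1e6:8.0f} M elements/s")
print(f"per-element: {sub.size / dt_call / 1e6:8.2f} M elements/s")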

There might be other issues. I will benchmark this carefully when switching the backend.

@mratsim
Owner

mratsim commented Dec 9, 2018

While working on Laser I've looked into many areas and found huge slowness in the exp and log functions from the C standard library <math.h>, which notably bottleneck sigmoid and softmax.

Benchmark.

On my machine with all optimisations, <math.h> can achieve about 160 million exponentials per second per thread, while the fastest implementation can achieve 1.3 billion (so about 8x faster).

PyTorch is not using the fastest implementation, but theirs can still achieve 900 million exponentials per second: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/avx_mathfun.h

I.e. I can't trust anyone :/.
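For anyone who wants to reproduce the order of magnitude on their own machine, a minimal throughput measurement (this times numpy's vectorized exp, not the C <math.h> benchmarks above; absolute numbers depend entirely on the machine and the build):

import time
import numpy as np

x = np.random.rand(10_000_000).astype(np.float32)

t0 = time.perf_counter()
y = np.exp(x)  # whichever exp implementation your numpy build ships
dt = time.perf_counter() - t0

print(f"{x.size / dt / 1e6:.0f} M exponentials/s")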

@andreaferretti
Contributor Author

Wow, one really has to be vigilant! :-o Although I guess that with the increasing trend towards lower and lower precision (16-bit, even 8-bit), it does not make much sense to use exp and log functions that were originally designed to be accurate for scientific computing - even crude approximations to the exponential would probably perform relatively well.
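As a concrete instance of such a crude approximation (a sketch of Schraudolph's 1999 bit-twiddling trick, not something Arraymancer or PyTorch necessarily uses): exp(x) can be approximated by writing a linear function of x directly into the exponent bits of an IEEE-754 double, at the cost of a few percent relative error.

import numpy as np

def fast_exp(x):
    # Schraudolph (1999): build the bit pattern of exp(x) directly.
    # a scales x into exponent units; b is the exponent bias (1023 << 20)
    # minus a magic constant (60801) that minimizes RMS relative error.
    x = np.asarray(x, dtype=np.float64)
    a = 2**20 / np.log(2)
    b = 1023 * 2**20 - 60801
    hi = (a * x + b).astype(np.int64)   # upper 32-bit word of the double
    return (hi << 32).view(np.float64)  # lower word = 0; reinterpret bits

x = np.linspace(-5.0, 5.0, 11)
print(np.max(np.abs(fast_exp(x) / np.exp(x) - 1)))  # max relative error: a few percent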
