Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fast IndexIterator for ChainerX CUDA #8360

Merged
merged 1 commit into from
Dec 11, 2019
Merged

Conversation

emcastillo
Copy link
Member

@emcastillo emcastillo commented Oct 31, 2019

Thanks to @asi1024 @shinh

ChainerX indexer used pretty expensive int64 division and modulo operations when calculating array indexes on CUDA.

This was noticeable when arrays were not contiguous, severely affecting the execution time of even simple kernels as ElementWise ones.

This PR replaces the code for index calculation with the same one as Cupy.

In the following test time for chainerx is reduced from 0.70 secs to 0.27

import numpy as np
import cupy
import chainer
import chainerx as chx
import time

def test_bench(name, xp, alloc_fn):
    np.random.seed(42)
    x = np.random.rand(800,389,2,66).astype(np.float32)
    y = np.random.rand(800,389,2,66).astype(np.float32)
    a = alloc_fn(x)
    b = alloc_fn(y)
    a = np.swapaxes(a, 2, 3)
    b = np.swapaxes(b, 2, 3)

    for i in range(1):
        cuda = xp.multiply(a,b)

    cupy.cuda.device.Device().synchronize()
    start = time.time()
    for i in range(400):
        cuda = xp.multiply(a,b)
    cupy.cuda.device.Device().synchronize()
    total = time.time() - start
    print(name, total)

test_bench('cupy', cupy, lambda x: cupy.array(x))
test_bench('chainerx', chx, lambda x: chx.array(x, device='cuda:0'))

@asi1024 asi1024 added the cat:performance Performance in terms of speed or memory consumption. label Oct 31, 2019
@emcastillo emcastillo changed the title Fast IndexIterator for CUDA [WIP] Fast IndexIterator for CUDA Oct 31, 2019
@emcastillo emcastillo added the ChainerX Related to ChainerX. label Oct 31, 2019
@emcastillo emcastillo changed the title [WIP] Fast IndexIterator for CUDA Fast IndexIterator for CUDA Nov 1, 2019
chainerx_cc/chainerx/index_iterator.h Outdated Show resolved Hide resolved
chainerx_cc/chainerx/index_iterator.h Outdated Show resolved Hide resolved
@emcastillo
Copy link
Member Author

PTAL

@asi1024
Copy link
Member

asi1024 commented Nov 6, 2019

Jenkins, test this please.

@emcastillo
Copy link
Member Author

One of the things I don't like, is that we have GPU specific code in a code that should be generic.
This was a "hot-fix" due to the performance problem it addresses but the cuda code needs to be factored out.

Probably overriding it with a specific CudaIndexer class in the cuda_device/ tree should be a nice refactoring.

chainerx_cc/chainerx/index_iterator.h Outdated Show resolved Hide resolved
chainerx_cc/chainerx/index_iterator.h Outdated Show resolved Hide resolved
@chainer-ci
Copy link
Member

Jenkins CI test (for commit 3fb14c8, target branch master) succeeded!

@asi1024 asi1024 self-requested a review November 6, 2019 12:18
@emcastillo
Copy link
Member Author

emcastillo commented Nov 7, 2019

Regarding the test.
I've been looking where to place it but none of the test files in cuda/ seems to be appropiate.
I think that this test can be carried out in python by allocating a non-contiguous array with more than 2^31 elements for the indexer to work with zeros, add 1 inplace and then perform the summation that should be equal to the number of elements.
Adding a scalar calls the elementwise kernel which uses the indexers in n-dims if the array is not contiguous.

# This is more than 2**31 elements 
a=chainerx.zeros(shape=(64,32,6*1024*1024*128//(16*16)), dtype=chainerx.int8, device='cuda:0')
a=a.swapaxes(2,0)
a+=1
assert not a.is_contiguous
assert a.sum() == a.shape[0]*a.shape[1]*a.shape[2]

The test should be skipped if the allocation fails due to the GPU not having enough memory?

@asi1024
Copy link
Member

asi1024 commented Nov 11, 2019

This PR depends on #8389.

@asi1024 asi1024 added the st:blocked-by-another-pr State indicating that another ticket is preventing this ticket from being closed/merged. label Nov 11, 2019
@asi1024 asi1024 removed the st:blocked-by-another-pr State indicating that another ticket is preventing this ticket from being closed/merged. label Nov 21, 2019
@asi1024
Copy link
Member

asi1024 commented Nov 21, 2019

Now #8389 is merged. Could you add tests?

@emcastillo
Copy link
Member Author

Sure, I have them already, let me just rebase and push

@emcastillo emcastillo force-pushed the fast_indexer branch 4 times, most recently from c78c8fb to cce2c8b Compare November 22, 2019 01:37
@emcastillo
Copy link
Member Author

PTAL

@asi1024
Copy link
Member

asi1024 commented Nov 22, 2019

Jenkins, test this please.

@asi1024
Copy link
Member

asi1024 commented Dec 3, 2019

Jenkins, test this please.

4 similar comments
@asi1024
Copy link
Member

asi1024 commented Dec 3, 2019

Jenkins, test this please.

@asi1024
Copy link
Member

asi1024 commented Dec 3, 2019

Jenkins, test this please.

@asi1024
Copy link
Member

asi1024 commented Dec 3, 2019

Jenkins, test this please.

@asi1024
Copy link
Member

asi1024 commented Dec 3, 2019

Jenkins, test this please.

@chainer-ci
Copy link
Member

Jenkins CI test (for commit d7a1168, target branch master) succeeded!

@asi1024
Copy link
Member

asi1024 commented Dec 3, 2019

flexCI, test this please.

@chainer-ci
Copy link
Member

Jenkins CI test (for commit 627df80, target branch master) succeeded!

@asi1024
Copy link
Member

asi1024 commented Dec 3, 2019

flexCI, test this please.

@chainer-ci
Copy link
Member

Jenkins CI test (for commit 9dcf904, target branch master) succeeded!

@asi1024
Copy link
Member

asi1024 commented Dec 4, 2019

flexCI, test this please.

1 similar comment
@asi1024
Copy link
Member

asi1024 commented Dec 4, 2019

flexCI, test this please.

@chainer-ci
Copy link
Member

Jenkins CI test (for commit bbcbbdc, target branch master) succeeded!

@asi1024
Copy link
Member

asi1024 commented Dec 4, 2019

flexCI, test this please.

@chainer-ci
Copy link
Member

Jenkins CI test (for commit 75c9fb8, target branch master) succeeded!

@emcastillo emcastillo removed this from the v7.0.0 milestone Dec 5, 2019
@asi1024
Copy link
Member

asi1024 commented Dec 8, 2019

Jenkins, test this please.

@chainer-ci
Copy link
Member

Jenkins CI test (for commit e26af2a, target branch master) succeeded!

@emcastillo
Copy link
Member Author

Jenkins, test this please

@emcastillo
Copy link
Member Author

Jenkins, test this please

@chainer-ci
Copy link
Member

Jenkins CI test (for commit 941d5e5, target branch master) succeeded!

@asi1024
Copy link
Member

asi1024 commented Dec 11, 2019

Travis failure seems unrelated to this PR. (ref. #8481)
Restarted the failing job.

@asi1024
Copy link
Member

asi1024 commented Dec 11, 2019

LGTM!

@asi1024 asi1024 merged commit a45b262 into chainer:master Dec 11, 2019
@asi1024 asi1024 added this to the v7.1.0 milestone Dec 11, 2019
@emcastillo emcastillo changed the title Fast IndexIterator for CUDA Fast IndexIterator for ChainerX CUDA Jan 16, 2020
@chainer-ci
Copy link
Member

Jenkins CI test (for commit 941d5e5, target branch master) succeeded!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cat:performance Performance in terms of speed or memory consumption. ChainerX Related to ChainerX.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants