Gaussian blur cpu performance #46

Open
mgoulao opened this issue Sep 10, 2022 · 12 comments
Labels
enhancement New feature or request

Comments

@mgoulao

mgoulao commented Sep 10, 2022

I have been experimenting with PIX because it can compute image augmentations on the GPU, whereas torchvision computes them on the CPU and requires multiple workers to avoid bottlenecks. Running some very simple timeit examples, I observed that a Gaussian blur takes a very long time on the CPU. I created a simple Colab notebook to demonstrate these experiments. I even tested transferring the image to the CPU before performing the blur, but it doesn't seem to make any difference. Is this intended, meaning I should not rely on CPU computation at all, or is something yet to be optimized for the CPU?

@claudiofantacci
Collaborator

Hi @mgoulao, thanks for reaching out! Yes, I've also tested this and it does not perform well on CPU. Transferring the image to CPU only helps a little: it's a gain of a few µs on an operation that takes several ms. This is not technically intended: the goal with PIX is to have implementations that perform well on TPUs/GPUs, and we accept whatever performance results from that on CPUs. This doesn't mean, of course, that we don't want to improve the CPU implementation as well 😄 Feel free to submit a PR with any optimisation for CPU!

@ASEM000

ASEM000 commented Oct 2, 2022

Hi @mgoulao

I made a JAX package for stencil computation (kernex) that can be used to compute a Gaussian blur.

I compared the performance of kernex and dm_pix on a Colab CPU using the following code.
The kernex-backed convolution seems to be faster on the CPU for this specific function.

Hope this helps
Best.

# !pip install dm_pix
# !pip install kernex 

import jax 
import jax.numpy as jnp
import kernex as kex
import dm_pix
import numpy.testing as npt 

def gaussian_blur(image, sigma, kernel_size):
    # 1D Gaussian taps centred on zero.
    x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size)
    # Gaussian exponent: -x^2 / (2 * sigma^2).
    w = jnp.exp(-0.5 * jnp.square(x / sigma))
    # Normalised 2D kernel built from the outer product of the 1D taps.
    w = jnp.outer(w, w)
    w = w / w.sum()

    # Apply the kernel as a stencil over every (kernel_size, kernel_size) window.
    @kex.kmap(kernel_size=(kernel_size, kernel_size), padding="same")
    def conv(x):
        return jnp.sum(x * w)

    return conv(image)

sigma = 1.0
kernel_size = 5

gaussian_blur_pix = jax.jit(lambda x: dm_pix.gaussian_blur(x, sigma, kernel_size))
gaussian_blur_kex = jax.jit(lambda x: gaussian_blur(x, sigma, kernel_size))

x = jax.random.uniform(jax.random.PRNGKey(0), (512,512))
xx = jnp.expand_dims(x, axis=2)
npt.assert_allclose(gaussian_blur_pix(xx)[:,:,0], gaussian_blur_kex(x), atol=1e-5)

# warm up
gaussian_blur_pix(xx)
gaussian_blur_kex(x)

%timeit gaussian_blur_pix(xx).block_until_ready()
%timeit gaussian_blur_kex(x).block_until_ready()
111 ms ± 40 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
11.1 ms ± 3.61 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)

On Colab GPU, it seems that kernex performs a bit better:

324 µs ± 111 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
200 µs ± 4.07 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
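
Outside a notebook, the `%timeit` measurements above can be approximated with the standard `timeit` module. The sketch below is only an illustration: `mean_time` is a hypothetical helper, not part of dm_pix or kernex, and the `repeats`/`number` defaults are simply chosen to mirror the "7 runs, 10 loops each" shown in the output. It warms the function up first so that JIT compilation is excluded, and calls `block_until_ready()` so that JAX's asynchronous dispatch does not hide the actual compute time.

```python
import timeit

import jax
import jax.numpy as jnp


def mean_time(fn, *args, repeats=7, number=10):
    """Average seconds per call of fn(*args), excluding JIT compilation.

    Hypothetical helper for illustration; fn must return a single JAX array.
    """
    # Warm up: the first call triggers tracing and compilation.
    fn(*args).block_until_ready()
    # block_until_ready() forces the asynchronous computation to finish
    # before the timer stops.
    timer = timeit.Timer(lambda: fn(*args).block_until_ready())
    runs = timer.repeat(repeat=repeats, number=number)
    return sum(runs) / (repeats * number)


# Toy usage with a stand-in function; swap in the two jitted blurs to get
# a time ratio between implementations.
double = jax.jit(lambda x: x * 2.0)
seconds = mean_time(double, jnp.ones((512, 512)))
print(f"{seconds * 1e6:.1f} µs per call")
```

A ratio such as `mean_time(gaussian_blur_pix, xx) / mean_time(gaussian_blur_kex, x)` would then give a single speed-up figure for a given input shape.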

@claudiofantacci
Collaborator

Thanks for reporting this as well! I'm a bit short of time at the moment, and I'm afraid that will be the case for the whole of October. I'll try to have a look asap, or otherwise at the beginning of November. In the meantime, if you come up with a better implementation that also works well on CPU without extra dependencies, feel free to submit a PR! 🚀

@ASEM000

ASEM000 commented Oct 3, 2022

Noted, thanks.
I will try to contribute in the coming days.
Best.

@claudiofantacci claudiofantacci added the enhancement New feature or request label Oct 4, 2022
@ASEM000

ASEM000 commented Oct 6, 2022

Hey,

I implemented dm_pix.gaussian_blur with no extra dependencies in this colab.

You can find testing and benchmarking against the depthwise-based implementation there.
On Colab CPU I'm getting the following speed-ups, computed from the average timeit time of the jitted version of each implementation.

# average time ratio pix/kex for 3x3 kernel
# (64, 64, 1):      12.17
# (128, 128, 1):    14.70
# (256, 256, 1):    17.38
# (512, 512, 1):    16.37
# (64, 64, 32):     62.64
# (128, 128, 32):   44.88
# (256, 256, 32):   36.19
# (512, 512, 32):   36.60
# (64, 64, 64):     42.34
# (128, 128, 64):   80.46
# (256, 256, 64):   57.42
# (512, 512, 64):   54.94

For GPU, the speed-up ratio is:

# average time ratio pix/kex for 3x3 kernel
# (64, 64, 1):      1.76
# (128, 128, 1):    1.87
# (256, 256, 1):    1.82
# (512, 512, 1):    1.98
# (64, 64, 32):     1.81
# (128, 128, 32):   2.67
# (256, 256, 32):   2.72
# (512, 512, 32):   5.24
# (64, 64, 64):     2.96
# (128, 128, 64):   1.78
# (256, 256, 64):   3.22
# (512, 512, 64):   8.81

Let me know if it's suitable for a PR

Best.
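
For readers without access to the colab, here is a minimal sketch of one way to write a dependency-free Gaussian blur in plain JAX. It is not necessarily what the colab does: it assumes a separable formulation (two 1D passes instead of one depthwise 2D convolution), an odd `kernel_size`, an `(H, W, C)` image, and zero padding at the borders. The names `gaussian_kernel_1d` and `separable_gaussian_blur` are hypothetical.

```python
import functools

import jax
import jax.numpy as jnp


def gaussian_kernel_1d(sigma, kernel_size):
    """Normalised 1D Gaussian taps centred on zero."""
    x = jnp.arange(kernel_size, dtype=jnp.float32) - (kernel_size - 1) / 2.0
    w = jnp.exp(-0.5 * jnp.square(x / sigma))
    return w / w.sum()


@functools.partial(jax.jit, static_argnames=("kernel_size",))
def separable_gaussian_blur(image, sigma, kernel_size):
    """Blur an (H, W, C) image with a vertical then a horizontal 1D pass.

    Assumes an odd kernel_size and zero padding (illustrative sketch only).
    """
    w = gaussian_kernel_1d(sigma, kernel_size)
    pad = kernel_size // 2
    height, width, _ = image.shape

    # Vertical pass: weighted sum of row-shifted copies of the padded image.
    padded = jnp.pad(image, ((pad, pad), (0, 0), (0, 0)))
    rows = sum(w[i] * padded[i:i + height] for i in range(kernel_size))

    # Horizontal pass: the same idea along the width axis.
    padded = jnp.pad(rows, ((0, 0), (pad, pad), (0, 0)))
    return sum(w[i] * padded[:, i:i + width] for i in range(kernel_size))


image = jax.random.uniform(jax.random.PRNGKey(0), (64, 64, 3))
blurred = separable_gaussian_blur(image, 1.0, 5)
print(blurred.shape)
```

The separable form costs about `2 * kernel_size` multiply-adds per pixel instead of `kernel_size ** 2`, which is one plausible source of a CPU speed-up; whether it accounts for the ratios reported above is not something this sketch establishes.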

@claudiofantacci
Collaborator

Thanks @ASEM000, I'll have a look at it as soon as I can; unfortunately, that will probably be at the end of the month 😭

@claudiofantacci
Collaborator

I just skimmed through the code, without checking the implementation details.
When you say kex there, you mean the new implementation, which does not use kex or any extra dependency. Is this right?

@ASEM000

ASEM000 commented Oct 6, 2022

Yes, you are right; sorry for the typo.

@claudiofantacci
Collaborator

That's ok. Skimming through, it looks good, but please let's resume this at the end of the month so I have more time to look into the code and give proper advice for submitting a PR 😄

@claudiofantacci
Collaborator

I'm finally back. I'll try to look into this asap!

@ASEM000

ASEM000 commented Dec 9, 2022

Hello,
Any updates or feedback?

Additionally, I implemented a Gaussian filter based on FFT depthwise convolution, which should be faster for large kernels.
https://github.com/ASEM000/serket/blob/main/serket/nn/blur.py
Let me know if you are interested, so I can provide a version with no extra dependencies.
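
To illustrate the FFT idea, here is a rough, dependency-free sketch in plain JAX. It is not the serket implementation linked above; it assumes an `(H, W, C)` image, an odd `kernel_size`, and zero-padded "same" output, and `fft_gaussian_blur` is a hypothetical name. The convolution theorem turns the per-channel 2D convolution into a pointwise product of spectra, so the cost is largely independent of the kernel size.

```python
import jax
import jax.numpy as jnp


def fft_gaussian_blur(image, sigma, kernel_size):
    """Blur an (H, W, C) image via FFT (illustrative sketch only)."""
    # Normalised 2D Gaussian kernel from the outer product of 1D taps.
    x = jnp.arange(kernel_size, dtype=jnp.float32) - (kernel_size - 1) / 2.0
    w = jnp.exp(-0.5 * jnp.square(x / sigma))
    w = w / w.sum()
    kernel = jnp.outer(w, w)

    height, width, _ = image.shape
    # Transform size of the full linear convolution, so that the circular
    # wrap-around of the FFT does not leak into the result.
    full_shape = (height + kernel_size - 1, width + kernel_size - 1)

    # Move channels first so the FFT runs over the last two (spatial) axes.
    channels_first = jnp.moveaxis(image, -1, 0)
    image_f = jnp.fft.rfft2(channels_first, s=full_shape)
    kernel_f = jnp.fft.rfft2(kernel, s=full_shape)

    # Pointwise product in frequency space, broadcast over channels.
    full = jnp.fft.irfft2(image_f * kernel_f, s=full_shape)

    # Crop the centred (H, W) window to get "same"-sized output.
    pad = kernel_size // 2
    same = full[:, pad:pad + height, pad:pad + width]
    return jnp.moveaxis(same, 0, -1)


image = jax.random.uniform(jax.random.PRNGKey(0), (64, 64, 3))
blurred = fft_gaussian_blur(image, 3.0, 15)
print(blurred.shape)
```

For small kernels such as 3x3 or 5x5, the direct approach is likely to remain faster; the FFT route tends to pay off only once the kernel becomes large relative to the transform overhead.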

@claudiofantacci
Collaborator

Hey @ASEM000, I have not forgotten about this 😄
I've been quite busy and should finally be back to a normal work regime; I'll try to look at all this asap 🚀


3 participants