<a href="https://colab.research.google.com/github/jecampagne/JaxTutos/blob/main/JAX_exo_sum_image_patches.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Objectives:
The goal is to sum patches of identical size cut from a 2D image, and to compare the compilation and execution times of several implementations on CPU and GPU. We propose different methods, and you are encouraged to propose your own.

In [2]:
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy as np
from functools import partial

In [19]:
# Display the installed JAX version (the timings below were recorded with it).
jax.__version__

'0.4.26'

In [18]:
# Display the devices of the default JAX backend.
jax.devices()

[cuda(id=0)]

# Compilation/Execution time on CPU & GPU devices

In [20]:
# Clear the JIT caches so that every function below is compiled from scratch.
jax.clear_caches()

expend_factor = 10  # image size scale factor [up to 10 on Colab CPU]
nxwrap,nywrap = 16, 10  # patch width (x, columns) and patch height (y, rows)
# Image extent along x and y: 5 (resp. 8) patches per unit of expend_factor.
Nx,Ny = nxwrap*5*expend_factor,nywrap*8*expend_factor
# The functions below index the image as im[y, x] (axis 0 = y, axis 1 = x),
# so the image shape must be (Ny, Nx).
im = 10 + jax.random.normal(jax.random.PRNGKey(1),(Ny,Nx))

## 2 nested explicit simple for loops

In [21]:
@partial(jax.jit, static_argnames=("nxwrap","nywrap"))
def make_fim_nested_for_loop(im,nxwrap,nywrap):
  """Sum all (nywrap, nxwrap) patches of `im` using two nested Python loops.

  The loops are plain Python, so they are executed while the function is
  traced: every patch addition becomes a separate operation of the jitted
  computation.

  Args:
    im: 2D array indexed as im[y, x].
    nxwrap: patch size along axis 1 (static).
    nywrap: patch size along axis 0 (static).

  Returns:
    Array of shape (nywrap, nxwrap) with the dtype of `im`. Only whole
    patches are summed; trailing rows/columns that do not fill a patch
    are ignored.
  """
  n_patch_y = im.shape[0] // nywrap
  n_patch_x = im.shape[1] // nxwrap

  acc = jnp.zeros((nywrap, nxwrap), dtype=im.dtype)

  # Visit the patches row by row, left to right.
  for py in range(n_patch_y):
    row0 = py * nywrap
    for px in range(n_patch_x):
      col0 = px * nxwrap
      acc = acc + im[row0:row0 + nywrap, col0:col0 + nxwrap]

  return acc

In [22]:
%time fim_nested_for_loop = make_fim_nested_for_loop(im,nxwrap,nywrap)

CPU times: user 28.5 s, sys: 359 ms, total: 28.8 s
Wall time: 29.3 s


In [23]:
# Repeated calls with the same arguments (compiled by the previous cell, assuming
# it has been run); block_until_ready waits for the asynchronously dispatched result.
%timeit _= make_fim_nested_for_loop(im,nxwrap,nywrap).block_until_ready()

329 µs ± 5.79 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


## nested 2-lax_fori_loop

In [24]:
@partial(jax.jit, static_argnames=("nxwrap","nywrap"))
def make_fim_nested_fori_loop(im,nxwrap,nywrap):
  """Sum all (nywrap, nxwrap) patches of `im` pixel by pixel.

  Two nested `jax.lax.fori_loop`s visit the pixels one at a time and
  add pixel (i, j) to the output cell (i % nywrap, j % nxwrap).

  Args:
    im: 2D array indexed as im[y, x].
    nxwrap: patch size along axis 1 (static).
    nywrap: patch size along axis 0 (static).

  Returns:
    Array of shape (nywrap, nxwrap) with the dtype of `im`. Only whole
    patches are summed; trailing rows/columns that do not fill a patch
    are ignored, as in the other implementations of this notebook.
  """
  # Restrict the loops to the region covered by whole patches; looping over
  # the full image would fold the leftover rows/columns into the result.
  n_rows = (im.shape[0] // nywrap) * nywrap
  n_cols = (im.shape[1] // nxwrap) * nxwrap

  def _body_i(i, fim):
        ii = i% nywrap

        def _body_j(j, fim):
              jj = j% nxwrap
              # `im` and `i` are closed over: only the accumulator is loop state.
              return fim.at[ii, jj].add(im[i, j])

        return jax.lax.fori_loop(0, n_cols, _body_j, fim)

  fim = jnp.zeros((nywrap, nxwrap), dtype=im.dtype)
  return jax.lax.fori_loop(0, n_rows, _body_i, fim)

In [25]:
%time fim_nested_fori_loop = make_fim_nested_fori_loop(im,nxwrap,nywrap)

CPU times: user 6.26 s, sys: 12.7 ms, total: 6.27 s
Wall time: 6.24 s


In [26]:
# Repeated calls with the same arguments (compiled by the previous cell, assuming
# it has been run); block_until_ready waits for the asynchronously dispatched result.
%timeit _= make_fim_nested_fori_loop(im,nxwrap,nywrap).block_until_ready()

6.36 s ± 215 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## 1-lax_fori_loop-unrolled

In [27]:
@partial(jax.jit, static_argnames=("nxwrap","nywrap"))
def make_fim_nested_for_loop_unroll(im, nxwrap, nywrap):
    """Sum all (nywrap, nxwrap) patches of `im` with one unrolled fori_loop.

    A single `jax.lax.fori_loop` runs over the flattened patch index; each
    iteration extracts one patch with `jax.lax.dynamic_slice` and adds it
    to the accumulator. The loop is requested with `unroll=10`.

    Args:
        im: 2D array indexed as im[y, x].
        nxwrap: patch size along axis 1 (static).
        nywrap: patch size along axis 0 (static).

    Returns:
        Array of shape (nywrap, nxwrap) with the dtype of `im`; only whole
        patches are summed.
    """
    n_patch_y = im.shape[0] // nywrap
    n_patch_x = im.shape[1] // nxwrap
    patch_shape = (nywrap, nxwrap)

    def _add_patch(flat_idx, acc):
        # Recover the (row, column) position of the patch in the patch grid.
        patch_row, patch_col = jnp.unravel_index(flat_idx, (n_patch_y, n_patch_x))
        corner = (patch_row * nywrap, patch_col * nxwrap)
        return acc + jax.lax.dynamic_slice(im, corner, patch_shape)

    acc0 = jnp.zeros(patch_shape, dtype=im.dtype)
    return jax.lax.fori_loop(0, n_patch_y * n_patch_x, _add_patch, acc0, unroll=10)

In [28]:
%time fim_nested_for_loop_unroll = make_fim_nested_for_loop_unroll(im,nxwrap,nywrap)

CPU times: user 1.04 s, sys: 15.4 ms, total: 1.06 s
Wall time: 837 ms


In [29]:
# Repeated calls with the same arguments (compiled by the previous cell, assuming
# it has been run); block_until_ready waits for the asynchronously dispatched result.
%timeit _ =  make_fim_nested_for_loop_unroll(im,nxwrap,nywrap).block_until_ready()

2.24 ms ± 27.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Index broadcasting & vmapping

In [30]:
@partial(jax.jit, static_argnames=("windi","windj"))
def make_fim_broadcast(arr, windi,windj):
  """Sum all (windi, windj) patches of `arr` using index broadcasting and vmap.

  The array is cut into consecutive, non-overlapping windows along axis 0
  by gathering with a broadcast index grid; the same cut is then applied
  to the other axis through `jax.vmap`, and the resulting stack of patches
  is summed.

  Args:
    arr: 2D array.
    windi: patch size along axis 0 (static).
    windj: patch size along axis 1 (static).

  Returns:
    Array of shape (windi, windj): the sum of all patches.
  """
  def _split_axis0(a, wind):
    # Index grid of shape (n_windows, wind): window start + offset in window.
    starts = jnp.arange(a.shape[0] - wind + 1)[::wind]
    offsets = jnp.arange(wind)
    return a[starts[:, None] + offsets[None, :]]

  # (n_i, windi, n_cols): windows along axis 0.
  blocks = _split_axis0(arr, windi)
  # (n_i, n_cols, windi): bring the column axis first within each block.
  blocks = jnp.moveaxis(blocks, -1, -2)
  # (n_i, n_j, windj, windi): windows along the column axis, block by block.
  blocks = jax.vmap(lambda block: _split_axis0(block, windj))(blocks)
  # (n_i * n_j, windi, windj): one entry per patch.
  patches = jnp.moveaxis(blocks.reshape(-1, windj, windi), -1, -2)
  return patches.sum(axis=0)


In [31]:
%time fim_broadcast= make_fim_broadcast(im,nywrap,nxwrap) # pay attention swap x-y

CPU times: user 73 ms, sys: 2.03 ms, total: 75.1 ms
Wall time: 70.7 ms


In [32]:
%timeit _ =  make_fim_broadcast(im,nxwrap,nywrap).block_until_ready() # pay attention swap x-y

238 µs ± 5.27 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [33]:
def max_diff(x,x_ref):
  """Return the largest absolute element-wise difference between `x` and `x_ref`."""
  abs_err = np.abs(x - x_ref)
  return np.max(abs_err)

In [34]:
# Largest absolute deviation of each alternative implementation from the
# nested-for-loop result, used here as the reference.
max_diff(fim_nested_fori_loop,fim_nested_for_loop), max_diff(fim_nested_for_loop_unroll,fim_nested_for_loop), max_diff(fim_broadcast,fim_nested_for_loop)

(0.0, 0.0, 1.8917489796876907e-10)