<a href="https://colab.research.google.com/github/jecampagne/JaxTutos/blob/main/JAX_exo_sum_image_patches.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Objectives:
The goal is to sum patches of identical size cut from a 2D image, and to see what the trends are in compilation and execution time on CPU & GPU. We propose several methods, and you are encouraged to propose your own.

In [1]:
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy as np
from functools import partial

In [2]:
jax.__version__

'0.4.26'

In [3]:
jax.devices()

[CpuDevice(id=0)]

# Compilation/Execution time on CPU & GPU devices

In [4]:
# Clear the JIT caches so that every method below is compiled from scratch
jax.clear_caches()

expend_factor = 10  # image scale factor [up to 10 on Colab CPU]
nxwrap,nywrap = 16, 10  # patch size: extent along x (columns) and along y (rows)
# 5*expend_factor patches along x and 8*expend_factor patches along y
Nx,Ny = nxwrap*5*expend_factor,nywrap*8*expend_factor
# Axis 0 is y and axis 1 is x: the make_fim_* functions cut im.shape[0] with
# nywrap and im.shape[1] with nxwrap, so the image must be (Ny, Nx).
im = jax.random.normal(jax.random.PRNGKey(1),(Ny,Nx))

## 2 nested explicit simple for loops

In [5]:
@partial(jax.jit, static_argnames=("nxwrap","nywrap"))
def make_fim_nested_for_loop(im,nxwrap,nywrap):
  """Sum the (nywrap, nxwrap) patches of `im` with two plain Python loops.

  The loops are ordinary Python `for` loops inside a jitted function, so one
  slice-and-add is emitted per patch when the function is traced.

  Args:
    im: array whose axis 0 is cut in steps of `nywrap` and axis 1 in steps
      of `nxwrap`.
    nxwrap: patch extent along axis 1 (static argument).
    nywrap: patch extent along axis 0 (static argument).

  Returns:
    Array of shape (nywrap, nxwrap) and dtype `im.dtype`: the sum of all whole
    patches. Rows/columns beyond the last whole patch are not visited.
  """
  # Exclusive upper bounds of the area covered by whole patches.
  y_stop = (im.shape[0] // nywrap) * nywrap
  x_stop = (im.shape[1] // nxwrap) * nxwrap

  fim = jnp.zeros((nywrap, nxwrap), dtype=im.dtype)
  # Row-major walk over the top-left corner of each patch.
  for y0 in range(0, y_stop, nywrap):
    for x0 in range(0, x_stop, nxwrap):
      fim = fim + im[y0:y0 + nywrap, x0:x0 + nxwrap]
  return fim

In [6]:
%time fim_nested_for_loop = make_fim_nested_for_loop(im,nxwrap,nywrap)

CPU times: user 12.1 s, sys: 267 ms, total: 12.4 s
Wall time: 21.7 s


In [7]:
%timeit _= make_fim_nested_for_loop(im,nxwrap,nywrap).block_until_ready()

1.18 ms ± 43.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


## nested 2-lax_fori_loop

In [8]:
@partial(jax.jit, static_argnames=("nxwrap","nywrap"))
def make_fim_nested_fori_loop(im,nxwrap,nywrap):
  """Sum the (nywrap, nxwrap) patches of `im` pixel by pixel with two nested
  `jax.lax.fori_loop`s.

  Each pixel im[r, c] is added to the output cell (r % nywrap, c % nxwrap).
  Both loops run over the full extent of `im`, so rows/columns beyond the
  last whole patch are folded in through the modulo as well.

  Args:
    im: array indexed as im[row, col].
    nxwrap: patch extent along axis 1 (static argument).
    nywrap: patch extent along axis 0 (static argument).

  Returns:
    Array of shape (nywrap, nxwrap) and dtype `im.dtype`.
  """
  n_rows, n_cols = im.shape[0], im.shape[1]

  def _accumulate_row(row, acc):
    # Inner loop over the columns of one image row; `im` and `row` are
    # captured by closure, only the accumulator is carried.
    def _accumulate_pixel(col, acc_row):
      return acc_row.at[row % nywrap, col % nxwrap].add(im[row, col])

    return jax.lax.fori_loop(0, n_cols, _accumulate_pixel, acc)

  empty = jnp.zeros((nywrap, nxwrap), dtype=im.dtype)
  return jax.lax.fori_loop(0, n_rows, _accumulate_row, empty)

In [9]:
%time fim_nested_fori_loop = make_fim_nested_fori_loop(im,nxwrap,nywrap)

CPU times: user 161 ms, sys: 2.94 ms, total: 164 ms
Wall time: 184 ms


In [10]:
%timeit _= make_fim_nested_fori_loop(im,nxwrap,nywrap).block_until_ready()

4.68 ms ± 339 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## 1-lax_fori_loop-unrolled

In [11]:
@partial(jax.jit, static_argnames=("nxwrap","nywrap"))
def make_fim_nested_for_loop_unroll(im, nxwrap, nywrap):
    """Sum the (nywrap, nxwrap) patches of `im` with a single
    `jax.lax.fori_loop` over the flattened patch index, unrolled by 10.

    Args:
        im: array whose axis 0 is cut in steps of `nywrap` and axis 1 in
            steps of `nxwrap`.
        nxwrap: patch extent along axis 1 (static argument).
        nywrap: patch extent along axis 0 (static argument).

    Returns:
        Array of shape (nywrap, nxwrap) and dtype `im.dtype`: the sum of all
        whole patches.
    """
    patch_shape = (nywrap, nxwrap)
    # Number of whole patches along each axis.
    n_patch_y = im.shape[0] // nywrap
    n_patch_x = im.shape[1] // nxwrap

    def _add_patch(flat_idx, acc):
        # Flat index -> (patch row, patch column), row-major.
        p_row, p_col = jnp.unravel_index(flat_idx, (n_patch_y, n_patch_x))
        corner = (p_row * nywrap, p_col * nxwrap)
        # `im` is captured by closure; only the accumulator is carried.
        return acc + jax.lax.dynamic_slice(im, corner, patch_shape)

    start = jnp.zeros(patch_shape, dtype=im.dtype)
    return jax.lax.fori_loop(0, n_patch_y * n_patch_x, _add_patch, start, unroll=10)

In [12]:
%time fim_nested_for_loop_unroll = make_fim_nested_for_loop_unroll(im,nxwrap,nywrap)

CPU times: user 370 ms, sys: 1.15 ms, total: 371 ms
Wall time: 369 ms


In [13]:
%timeit _ =  make_fim_nested_for_loop_unroll(im,nxwrap,nywrap).block_until_ready()

263 µs ± 12.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


## Index broadcasting & vmapping

In [14]:
@partial(jax.jit, static_argnames=("windi","windj"))
def make_fim_broadcast(arr, windi,windj):
  def rolling_window_i(arr, wind):
    idx = jnp.arange(arr.shape[0] - wind + 1)[::wind, None] + jnp.arange(wind)[None, :]
    return arr[idx]
  y = rolling_window_i(arr, windi)
  y = jnp.moveaxis(y, -1, -2)
  y = jax.vmap(partial(rolling_window_i,wind=windj))(y)
  y = y.reshape(-1,windj,windi)
  return jnp.moveaxis(y, -1, -2).sum(axis=0)


In [15]:
%time fim_broadcast= make_fim_broadcast(im,nywrap,nxwrap) # pay attention swap x-y

CPU times: user 105 ms, sys: 2.95 ms, total: 108 ms
Wall time: 112 ms


In [16]:
%timeit _ =  make_fim_broadcast(im,nxwrap,nywrap).block_until_ready() # pay attention swap x-y

2.05 ms ± 49.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [17]:
def max_diff(x,x_ref):
  """Return the largest absolute element-wise difference between `x` and `x_ref`."""
  abs_err = np.abs(x - x_ref)
  return np.max(abs_err)

In [18]:
max_diff(fim_nested_fori_loop,fim_nested_for_loop), max_diff(fim_nested_for_loop_unroll,fim_nested_for_loop), max_diff(fim_broadcast,fim_nested_for_loop)

(0.0, 0.0, 1.8917489796876907e-10)