<a href="https://colab.research.google.com/github/jecampagne/JaxTutos/blob/main/sum_of_image_patches.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


# Objectives:
The goal is to sum all non-overlapping patches of identical size in a 2D image into a single patch-sized array, and to find the fastest way to do it on CPU and GPU. We propose three methods, but you can code your own.

In [2]:
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy as np
from functools import partial

In [3]:
@partial(jax.jit, static_argnames=("nxwrap", "nywrap"))
def make_fim_nested_for_loop(im, nxwrap, nywrap):
  """Sum every full (nywrap, nxwrap) patch of a 2D image using plain Python loops.

  The loops execute while JAX traces the function, so the traced program
  contains one slice-and-add per patch and grows linearly with the number of
  patches. Rows/columns past the last full patch are ignored.

  Args:
    im: 2D array, indexed (row, col).
    nxwrap: patch width (number of columns); static.
    nywrap: patch height (number of rows); static.

  Returns:
    Array of shape (nywrap, nxwrap) with dtype ``im.dtype``: the element-wise
    sum of all full patches, accumulated in row-major patch order.
  """
  acc = jnp.zeros((nywrap, nxwrap), dtype=im.dtype)

  n_patch_cols = im.shape[1] // nxwrap
  n_patch_rows = im.shape[0] // nywrap

  # Top-left corners of the patches, visited row of patches by row of patches.
  for top in range(0, n_patch_rows * nywrap, nywrap):
    for left in range(0, n_patch_cols * nxwrap, nxwrap):
      acc = acc + im[top:top + nywrap, left:left + nxwrap]

  return acc

In [4]:
@partial(jax.jit, static_argnames=("nxwrap", "nywrap"))
def make_fim_nested_fori_loop(im, nxwrap, nywrap):
  """Fold a 2D image into an (nywrap, nxwrap) array one pixel at a time.

  Two nested ``jax.lax.fori_loop`` calls visit every pixel (i, j) in row-major
  order and add it to output cell (i % nywrap, j % nxwrap) with a scalar
  scatter-add. Because of the modulo, rows/columns past the last full patch
  are wrapped into the result instead of being dropped (the slice-based
  variants in this notebook drop them); for image sizes that are exact
  multiples of the patch size the results coincide.

  Args:
    im: 2D array, indexed (row, col).
    nxwrap: patch width (number of columns); static.
    nywrap: patch height (number of rows); static.

  Returns:
    Array of shape (nywrap, nxwrap) with dtype ``im.dtype``.
  """
  n_rows = im.shape[0]
  n_cols = im.shape[1]

  def accumulate_row(i, acc):
    def accumulate_pixel(j, acc_row):
      # One scalar scatter-add per pixel; `im` is captured, not carried.
      return acc_row.at[i % nywrap, j % nxwrap].add(im[i, j])

    return jax.lax.fori_loop(0, n_cols, accumulate_pixel, acc)

  init = jnp.zeros((nywrap, nxwrap), dtype=im.dtype)
  return jax.lax.fori_loop(0, n_rows, accumulate_row, init)

In [5]:
@partial(jax.jit, static_argnames=("windi", "windj"))
def make_fim_roll(arr, windi, windj):
  """Sum every full (windi, windj) patch of a 2D array using index gathers.

  NOTE: the size arguments are ordered (rows, cols) = (windi, windj), which is
  the opposite of the make_fim_nested_* functions' (nxwrap, nywrap) order.
  Rows/columns past the last full patch are ignored.

  Args:
    arr: 2D array, indexed (row, col).
    windi: patch height (number of rows); static.
    windj: patch width (number of columns); static.

  Returns:
    Array of shape (windi, windj): the element-wise sum of all full patches.
  """
  def blocks_along_axis0(a, size):
    # Gather consecutive, non-overlapping windows of `size` entries along
    # axis 0; a trailing partial window is dropped.
    starts = jnp.arange(a.shape[0] - size + 1)[::size]
    return a[starts[:, None] + jnp.arange(size)[None, :]]

  # (n_i, windi, W): bands of windi rows.
  bands = blocks_along_axis0(arr, windi)
  # (n_i, W, windi): put columns on axis 1 so they can be windowed next.
  bands = jnp.swapaxes(bands, -2, -1)
  # (n_i, n_j, windj, windi): window each band's columns.
  tiles = jax.vmap(lambda band: blocks_along_axis0(band, windj))(bands)
  # Flatten the patch grid, restore (row, col) order inside each patch, reduce.
  tiles = tiles.reshape(-1, windj, windi)
  return jnp.swapaxes(tiles, -2, -1).sum(axis=0)


In [6]:
# Clear JAX's compilation caches so the timings below start from freshly compiled code.
jax.clear_caches()

# Patch size: nxwrap columns (x) by nywrap rows (y).
nxwrap,nywrap = 16, 10
# Image size: a whole number of patches along each axis (50 along x, 80 along y).
Nx,Ny = nxwrap*50,nywrap*80
# Arrays are indexed (row, col) = (y, x), so the shape is (Ny, Nx): the
# patch-summing functions divide im.shape[0] by nywrap and im.shape[1] by nxwrap.
im = jax.random.normal(jax.random.PRNGKey(1),(Ny,Nx))

In [7]:
%%time
fim_nested_for_loop = make_fim_nested_for_loop(im,nxwrap,nywrap)

CPU times: user 29.8 s, sys: 2.04 s, total: 31.8 s
Wall time: 36.6 s


In [8]:
%%time
for _ in range(10):
  _= make_fim_nested_for_loop(im,nxwrap,nywrap)

CPU times: user 619 µs, sys: 0 ns, total: 619 µs
Wall time: 628 µs


In [9]:
%%time
fim_nested_fori_loop = make_fim_nested_fori_loop(im,nxwrap,nywrap)

CPU times: user 6.75 s, sys: 46.7 ms, total: 6.8 s
Wall time: 6.89 s


In [10]:
%%time
for _ in range(10):
  _= make_fim_nested_fori_loop(im,nxwrap,nywrap)

CPU times: user 1min 3s, sys: 80.2 ms, total: 1min 3s
Wall time: 1min 3s


In [11]:
%%time
fim_roll= make_fim_roll(im,nywrap,nxwrap)

CPU times: user 84.2 ms, sys: 15 µs, total: 84.2 ms
Wall time: 144 ms


In [12]:
%%time
for _ in range(10):
  _= make_fim_roll(im,nxwrap,nywrap)

CPU times: user 71.7 ms, sys: 5.92 ms, total: 77.6 ms
Wall time: 122 ms


In [14]:
# Max absolute deviation of each method from the nested-for-loop reference.
err_fori = np.max(np.abs(fim_nested_fori_loop-fim_nested_for_loop))
err_roll = np.max(np.abs(fim_roll-fim_nested_for_loop))
# Fail loudly on a real mismatch. Summation order differs between methods (and
# may differ across backends), so only agreement up to round-off is expected.
np.testing.assert_allclose(fim_nested_fori_loop, fim_nested_for_loop, atol=1e-9)
np.testing.assert_allclose(fim_roll, fim_nested_for_loop, atol=1e-9)
err_fori, err_roll

(0.0, 3.836930773104541e-13)