<a href="https://colab.research.google.com/github/iree-org/iree-jax/blob/colab/conways_game_of_life.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install git+https://github.com/iree-org/iree-jax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/iree-org/iree-jax
  Cloning https://github.com/iree-org/iree-jax to /tmp/pip-req-build-ncl8bud2
  Running command git clone -q https://github.com/iree-org/iree-jax /tmp/pip-req-build-ncl8bud2
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Collecting jaxlib>=0.3.15
  Downloading jaxlib-0.3.15-cp37-none-manylinux2014_x86_64.whl (72.0 MB)
[K     |████████████████████████████████| 72.0 MB 122 kB/s 
[?25hCollecting iree-runtime>=20220811.232
  Downloading iree_runtime-20220811.232-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)
[K     |████████████████████████████████| 2.2 MB 24.2 MB/s 
[?25hCollecting jax>=0.3.16
  Downloading jax-0.3.16.tar.gz (1.0 MB)
[K     |████████████████████████████████| 1.0 MB 69.7 MB/s 
[?25hColl

In [None]:
import iree
import iree.jax
import jax
import jax.lax as lax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import plotly.graph_objects as go

In [None]:
# Define the basic computation equivalent to Conway's Game of Life.
def conways_game(input, it):
  """Advances a Game of Life board by `it` generations.

  Args:
    input: 2-D integer board where 1 marks a living cell and 0 a dead one.
      The loop body produces int32, so `input` is expected to be int32 as
      well (`fori_loop` requires the carry dtype to stay the same). Cells
      outside the board are treated as dead, because the window sum is
      padded with the init value 0.
    it: Number of generations to run. It is passed straight through as the
      upper bound of `jax.lax.fori_loop`.

  Returns:
    The board after `it` generations, with the same shape as `input`.
  """
  def body(i, x):
    # Count the living cells in each 3x3 block. The sum includes the cell
    # being checked, so it equals (living neighbours) + (1 if x is alive).
    pool = jax.lax.reduce_window(x, 0, lax.add, (3, 3), window_strides=(1, 1),
                                 padding=((1, 1), (1, 1)))

    # A sum of 4 on a *living* cell means exactly 3 living neighbours, so it
    # survives. The aliveness test must be on `x`, not on `pool`: a dead cell
    # with 4 living neighbours also sums to 4 but must stay dead.
    stay_alive = jnp.logical_and(x, pool == 4)

    # A sum of 3 means either the cell is alive with 2 neighbours (survival)
    # or dead with 3 neighbours (birth). Both yield a living cell.
    become_alive = pool == 3

    # The cell is alive in the next generation if either case holds.
    return jnp.logical_or(stay_alive, become_alive).astype(np.int32)

  # `body` ignores the loop index `i`; `fori_loop` requires the (i, val) form.
  return jax.lax.fori_loop(0, it, body, input)


In [None]:
# Initialize the board size and a random initial state.
# A fixed seed makes the starting board, and therefore every frame of the
# animation below, reproducible under Restart & Run All.
SEED = 0
width = 128
height = 128
rng = np.random.default_rng(SEED)
# Draws are uniform over 0..255 and only 241..255 pass the `> 240` threshold,
# so each cell starts alive with probability 15/256 (~6%).
# numpy shapes are (rows, cols), i.e. (height, width).
x0 = (rng.integers(0, 256, size=(height, width)) > 240).astype(np.int32)
xn = x0

# Setup the game state.
# Wraps the game in an `iree.jax.Program` so it can be compiled with IREE.
# NOTE(review): the stateful semantics below follow iree-jax's Program API;
# confirm against its documentation.
class Conways(iree.jax.Program):

  # Store the internal board state. It is initialized from the board `xn`
  # built above; `main` reads it and overwrites it on every call.
  _x = xn

  # Kernel that runs `it` generations of the game on board `x` via
  # `conways_game`. There is deliberately no `self` parameter; presumably
  # `iree.jax.kernel` handles this - confirm against iree-jax.
  @iree.jax.kernel
  def _conway(x, it):
    return conways_game(x, it)

  # Entry point: advances the stored board by `it` generations, writes the
  # result back into `_x` and returns it. `iree.jax.like(1)` appears to declare
  # the argument type from the example value 1 (assumed: an integer scalar -
  # TODO confirm).
  def main(self, it=iree.jax.like(1)):
    self._x = self._conway(self._x, it)
    return self._x


# Compile the program with IREE. Exported functions are then looked up by name
# on the runtime module, e.g. runtime["main"].
program = iree.jax.IREE.compile_program(Conways())
runtime = program.runtime_module

In [None]:
# Advance the compiled program one generation per call and record each frame.
# NOTE: the board is stored inside the compiled program (`Conways._x`), so
# re-running this cell presumably continues from where the last run stopped
# rather than restarting from `x0`.
NUM_FRAMES = 100
STEP = np.asarray(1, dtype=np.int32)  # generations advanced per call to "main"

images = []
for _ in range(NUM_FRAMES):
  xn = runtime["main"](STEP)
  images.append(np.asarray(xn))

# Shape: (NUM_FRAMES, rows, cols); the leading axis becomes the animation frame.
images = np.stack(images)

fig = px.imshow(images.astype(np.single), animation_frame=0,
                title="Conway's Game of Life")
# Hide the colour bar and axes: the board is binary, so they carry no information.
fig.update_layout(coloraxis_showscale=False)
fig.update_traces(showscale=False)
fig.update_xaxes(visible=False)
fig.update_yaxes(visible=False)
fig.show()
