
Explore the `jax.pmap` and `jax.lax.all_gather` functions.

Inspired by the `jax.lax.all_gather` documentation: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.all_gather.html

```python
# Understand jax.pmap
import os
# Simulate 8 devices on the host CPU. This must be set before `import jax`.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
import jax
import numpy as np
```
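The flag only takes effect if it is set before JAX initializes its backends, which is why `XLA_FLAGS` is set before `import jax`. A quick sanity check, assuming JAX has not already been imported earlier in the same Python process (on a GPU or TPU runtime the default devices are the accelerators; the flag only affects the host CPU platform, so this sketch asks for the CPU devices explicitly):

```python
import os

# Must run before `import jax`: the CPU backend reads XLA_FLAGS once, when it initializes.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax

# Ask for the host CPU devices explicitly, whatever the default backend is.
cpu_devices = jax.devices('cpu')
print(len(cpu_devices))  # 8 if the flag was picked up
print(cpu_devices)
```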


# jax.pmap + jax.lax.all_gather

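```python
# pmap maps over the leading axis: device d receives the scalar x[d] (4 of the 8 devices are used).
x = np.arange(4)

def f(x):
  # Gather the value of x from every device along the named axis 'i'.
  return jax.lax.all_gather(x, 'i')

def g(x):
  # Same gather, plus this device's own scalar x (broadcast add).
  return jax.lax.all_gather(x, 'i') + x

fy = jax.pmap(f, axis_name='i')(x)
print('-'*80)
print(fy)

gy = jax.pmap(g, axis_name='i')(x)
print('-'*80)
print(gy)
```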


```
--------------------------------------------------------------------------------
[[0 1 2 3]
 [0 1 2 3]
 [0 1 2 3]
 [0 1 2 3]]
--------------------------------------------------------------------------------
[[0 1 2 3]
 [1 2 3 4]
 [2 3 4 5]
 [3 4 5 6]]
```


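Explain the result `fy` (row `d` is the output computed on device `d`):

```python
[[0 1 2 3]   # device 0: x == 0; all_gather == [0 1 2 3] on every device
 [0 1 2 3]   # device 1: x == 1; same gathered result
 [0 1 2 3]   # device 2: x == 2; same gathered result
 [0 1 2 3]]  # device 3: x == 3; same gathered result
```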
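To make the semantics concrete, here is a minimal NumPy model of what `all_gather` returns on each device. `simulate_all_gather` is a hypothetical helper for illustration, not how JAX implements the collective: it assumes `shards[d]` is the value of `x` on device `d`, and every device receives the same stack of all shards, in device order.

```python
import numpy as np

def simulate_all_gather(shards):
    """Hypothetical helper: the all_gather result seen by each device.

    shards[d] is the value of `x` on device d. Every device receives the
    same array: all shards stacked along a new leading axis, in device order.
    """
    gathered = np.stack(shards)
    # One identical copy of the gathered array per device.
    return np.stack([gathered for _ in range(len(shards))])

x = np.arange(4)                   # device d holds the scalar x[d]
fy = simulate_all_gather(list(x))
print(fy)
# [[0 1 2 3]
#  [0 1 2 3]
#  [0 1 2 3]
#  [0 1 2 3]]
```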

Explain the result `gy`:

```python
[[0 1 2 3]   # == [0 1 2 3] + 0: [0 1 2 3] is the all_gather result on every device; 0 is the value of x on device 0.
 [1 2 3 4]   # == [0 1 2 3] + 1
 [2 3 4 5]   # == [0 1 2 3] + 2
 [3 4 5 6]]  # == [0 1 2 3] + 3
```
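Collectives such as `all_gather` also work under `jax.vmap` when it is given an `axis_name`, so the same `f` and `g` can be checked on a single device. A sketch, assuming a recent JAX version; it deliberately avoids `axis_index_groups`, which is only exercised with `pmap` in this notebook. Row `d` of `gy` is `x + x[d]`, i.e. the outer sum of `x` with itself:

```python
import jax
import numpy as np

def f(x):
    # Every mapped element receives the full gathered vector [0 1 2 3].
    return jax.lax.all_gather(x, 'i')

def g(x):
    # The gathered vector plus this element's own scalar x (broadcast add).
    return jax.lax.all_gather(x, 'i') + x

x = np.arange(4)
fy = jax.vmap(f, axis_name='i')(x)
gy = jax.vmap(g, axis_name='i')(x)
print(fy)
print(gy)

# gy[d, j] == x[j] + x[d], the outer sum of x with itself.
print(np.array_equal(np.asarray(gy), x[None, :] + x[:, None]))  # True
```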

# jax.lax.all_gather(..., axis_index_groups=)

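```python
x = np.arange(16).reshape(4, 4)  # device d receives the row x[d], shape (4,)
print('x', '-'*80)
print(x)


def f(x):
  # Gather all 4 rows (shape (4, 4)), then add this device's row to each of them.
  return jax.lax.all_gather(x, 'i') + x

def g(x):
  # Gather only within this device's group: devices {0, 2} and {3, 1}.
  # Each device gets shape (2, 4); rows appear in the order the group lists the devices.
  return jax.lax.all_gather(x, 'i', axis_index_groups=[[0, 2], [3, 1]]) + x

fy = jax.pmap(f, axis_name='i')(x)
print('fy', '-'*80)
print(fy)

gy = jax.pmap(g, axis_name='i')(x)
print('gy', '-'*80)
print(gy)
```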


```
x --------------------------------------------------------------------------------
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
fy --------------------------------------------------------------------------------
[[[ 0  2  4  6]
  [ 4  6  8 10]
  [ 8 10 12 14]
  [12 14 16 18]]

 [[ 4  6  8 10]
  [ 8 10 12 14]
  [12 14 16 18]
  [16 18 20 22]]

 [[ 8 10 12 14]
  [12 14 16 18]
  [16 18 20 22]
  [20 22 24 26]]

 [[12 14 16 18]
  [16 18 20 22]
  [20 22 24 26]
  [24 26 28 30]]]
gy --------------------------------------------------------------------------------
[[[ 0  2  4  6]
  [ 8 10 12 14]]

 [[16 18 20 22]
  [ 8 10 12 14]]

 [[ 8 10 12 14]
  [16 18 20 22]]

 [[24 26 28 30]
  [16 18 20 22]]]
```


Explain the result `fy` (block `d` is the output computed on device `d`):

```python
[
 # Device 0: x == [0 1 2 3]; all_gather == [[0 1 2 3] ... [12 13 14 15]] on every device.
 # This block is x + all_gather on this device.
 [[ 0  2  4  6]
  [ 4  6  8 10]
  [ 8 10 12 14]
  [12 14 16 18]]

 # Device 1: x == [4 5 6 7]; the gathered matrix is the same as above.
 [[ 4  6  8 10]
  [ 8 10 12 14]
  [12 14 16 18]
  [16 18 20 22]]

 # Device 2: x == [8 9 10 11].
 [[ 8 10 12 14]
  [12 14 16 18]
  [16 18 20 22]
  [20 22 24 26]]

 # Device 3: x == [12 13 14 15].
 [[12 14 16 18]
  [16 18 20 22]
  [20 22 24 26]
  [24 26 28 30]]]
```
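Since every device sees the same gathered matrix, block `d` of `fy` is that matrix with device `d`'s row added to every row. In NumPy terms (a sketch that treats row `d` of `x` as the shard on device `d`), the whole result is one broadcast addition:

```python
import numpy as np

x = np.arange(16).reshape(4, 4)   # row d is the shard on device d

# fy[d, j, :] == x[j, :] + x[d, :]
#   x[None, :, :] supplies the gathered rows x[j] (identical for every device d),
#   x[:, None, :] supplies the local shard x[d], broadcast over j.
fy = x[None, :, :] + x[:, None, :]

print(fy.shape)   # (4, 4, 4)
print(fy[0])
# [[ 0  2  4  6]
#  [ 4  6  8 10]
#  [ 8 10 12 14]
#  [12 14 16 18]]
```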

Explain the result `gy` with `axis_index_groups=[[0, 2], [3, 1]]`:

```python
[
 # Device 0: x == [0 1 2 3].
 # all_gather == [[0 1 2 3] [8 9 10 11]] because device 0 is in group [0, 2].
 # x + all_gather == ...
 [[ 0  2  4  6]
  [ 8 10 12 14]]

 # Device 1: x == [4 5 6 7].
 # all_gather == [[12 13 14 15] [4 5 6 7]] because device 1 is in group [3, 1]
 # (device 3's row comes first, following the order of the group).
 [[16 18 20 22]
  [ 8 10 12 14]]

 # Device 2: x == [8 9 10 11].
 # all_gather == [[0 1 2 3] [8 9 10 11]] (group [0, 2], as for device 0).
 [[ 8 10 12 14]
  [16 18 20 22]]

 # Device 3: x == [12 13 14 15].
 # all_gather == [[12 13 14 15] [4 5 6 7]] (group [3, 1], as for device 1).
 [[24 26 28 30]
  [16 18 20 22]]]
```
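The grouped case can be modeled the same way. In this sketch, `simulate_grouped_all_gather` is a hypothetical helper, not a JAX API: device `d` gathers only the shards of the devices in its own group, stacked in the order the group lists them. That ordering is an assumption taken from the output above, where group `[3, 1]` yields `x[3]` before `x[1]`.

```python
import numpy as np

def simulate_grouped_all_gather(x, groups):
    """Hypothetical helper modeling all_gather with axis_index_groups.

    x[d] is the shard on device d; `groups` partitions the device indices.
    Device d receives the shards of its own group, in the order listed.
    """
    out = []
    for d in range(len(x)):
        group = next(grp for grp in groups if d in grp)
        out.append(np.stack([x[k] for k in group]))
    return np.stack(out)

x = np.arange(16).reshape(4, 4)
gathered = simulate_grouped_all_gather(x, [[0, 2], [3, 1]])  # shape (4, 2, 4)

# g(x) adds each device's own shard, broadcast over the gathered rows.
gy = gathered + x[:, None, :]
print(gy.shape)  # (4, 2, 4)
print(gy[1])
# [[16 18 20 22]
#  [ 8 10 12 14]]
```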

