<a href="https://colab.research.google.com/github/mohsenh17/jaxLearning/blob/main/Parallel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import jax.numpy as jnp
import numpy as np

from jax import grad, jit, vmap, pmap

from jax import random
import matplotlib.pyplot as plt

import functools

# Independent model

Each device convolves its own slice of the batch; no cross-device communication is needed.

In [3]:
available_devices = jax.devices()
available_devices  # last expression: rendered as the cell output

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [4]:
x = np.arange(0, 5)               # signal: 0, 1, 2, 3, 4
w = np.asarray([2.0, 3.0, 4.0])   # 3-tap kernel

def convolve(w, x):
    """Valid-mode 1D sliding dot product of signal `x` with kernel `w`.

    The kernel is NOT flipped, so strictly speaking this is a
    cross-correlation (a true convolution with the reversed kernel).

    Args:
        w: 1D kernel of length k (any k >= 1).
        x: 1D signal of length n.

    Returns:
        jnp array of length max(n - k + 1, 0) whose element i is
        dot(x[i:i+k], w). It is empty when the signal is shorter than the
        kernel.
    """
    k = len(w)
    # len() only reads static shapes, so this also works on tracers inside
    # jit / vmap / pmap.
    output = [jnp.dot(x[i:i + k], w) for i in range(len(x) - k + 1)]
    return jnp.array(output)

result = convolve(w=w, x=x)
print(f'{result!r}')

Array([11., 20., 29.], dtype=float32)


In [22]:
n_devices = jax.local_device_count()
print(f'Number of available devices: {n_devices}')

# One row of 5 samples per device; the kernel is replicated once per device.
xs = np.arange(n_devices * 5).reshape(n_devices, 5)
ws = np.tile(w, (n_devices, 1))

print(xs.shape, ws.shape)
jax.debug.visualize_array_sharding(jnp.array(xs))

Number of available devices: 8
(8, 5) (8, 3)


In [23]:
batched_convolve = jax.vmap(convolve)  # maps over the leading axis of both ws and xs
vmap_result = batched_convolve(ws, xs)
print(f'{vmap_result!r}')
jax.debug.visualize_array_sharding(vmap_result)

Array([[ 11.,  20.,  29.],
       [ 56.,  65.,  74.],
       [101., 110., 119.],
       [146., 155., 164.],
       [191., 200., 209.],
       [236., 245., 254.],
       [281., 290., 299.],
       [326., 335., 344.]], dtype=float32)


In [25]:
parallel_convolve = jax.pmap(convolve)  # one row of ws / xs per device
pmap_result = parallel_convolve(ws, xs)
print(f'{pmap_result!r}')
jax.debug.visualize_array_sharding(pmap_result)

Array([[ 11.,  20.,  29.],
       [ 56.,  65.,  74.],
       [101., 110., 119.],
       [146., 155., 164.],
       [191., 200., 209.],
       [236., 245., 254.],
       [281., 290., 299.],
       [326., 335., 344.]], dtype=float32)


In [27]:
# in_axes=(None, 0): broadcast the single kernel `w` to every device and split `xs` by row.
broadcast_kernel_convolve = jax.pmap(convolve, in_axes=(None, 0))
pmap_smarter_result = broadcast_kernel_convolve(w, xs)
print(f'{pmap_smarter_result!r}')

Array([[ 11.,  20.,  29.],
       [ 56.,  65.,  74.],
       [101., 110., 119.],
       [146., 155., 164.],
       [191., 200., 209.],
       [236., 245., 254.],
       [281., 290., 299.],
       [326., 335., 344.]], dtype=float32)


In [26]:
pmapped_convolve = jax.pmap(convolve)
# First pass: each device's 3 outputs become that device's kernel for the second pass.
first_pass = pmapped_convolve(ws, xs)
double_pmap_result = pmapped_convolve(first_pass, xs)
print(f'{double_pmap_result!r}')

Array([[   78.,   138.,   198.],
       [ 1188.,  1383.,  1578.],
       [ 3648.,  3978.,  4308.],
       [ 7458.,  7923.,  8388.],
       [12618., 13218., 13818.],
       [19128., 19863., 20598.],
       [26988., 27858., 28728.],
       [36198., 37203., 38208.]], dtype=float32)


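# Dependent model

Each device's output is normalized by a sum over *all* devices. The sum is computed with the collective `jax.lax.psum` over the named axis `batch_dim`.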

In [30]:
def normalized_convolution(w, x):
    """Convolve `x` with `w`, then normalize by the elementwise sum across devices.

    Must be called under a transform that binds the named axis 'batch_dim'
    (e.g. ``jax.pmap(..., axis_name='batch_dim')``). ``jax.lax.psum`` sums the
    per-device outputs over that axis. After the division, each output
    position therefore sums to 1 across the mapped axis, assuming the total
    is nonzero.

    Args:
        w: 1D kernel, forwarded to ``convolve``.
        x: this device's 1D signal slice.

    Returns:
        jnp array shaped like ``convolve(w, x)``.
    """
    # Reuse the shared helper instead of duplicating the sliding-window loop.
    output = convolve(w, x)
    # psum is a collective: every device receives the same elementwise total.
    return output / jax.lax.psum(output, axis_name='batch_dim')

res_pmap = jax.pmap(normalized_convolution, axis_name='batch_dim', in_axes=(None, 0))(w, xs)

print(repr(res_pmap))

# Check every output position, not just the first column: each column should sum to 1.
column_sums = res_pmap.sum(axis=0)
print(f'Verify the output is normalized: {column_sums}')
np.testing.assert_allclose(column_sums, 1.0, rtol=1e-5)

Array([[0.00816024, 0.01408451, 0.019437  ],
       [0.04154303, 0.04577465, 0.04959785],
       [0.07492582, 0.07746479, 0.07975871],
       [0.10830861, 0.10915492, 0.10991956],
       [0.14169139, 0.14084506, 0.14008042],
       [0.17507419, 0.17253521, 0.17024128],
       [0.20845698, 0.20422535, 0.20040214],
       [0.24183977, 0.23591548, 0.23056298]], dtype=float32)
Verify the output is normalized: 1.0
