In [1]:
import jax

In [2]:
# NOTE(review): presumably this starts jax-smi's background memory tracking so
# device usage can be inspected while the notebook runs — confirm against the
# jax_smi documentation.
from jax_smi import initialise_tracking
initialise_tracking()
# Placeholder: the computation to be tracked lives in the cells that follow.

In [3]:
# Show the devices JAX reports for this runtime (listed in the cell output).
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [5]:
import numpy as np
import jax.numpy as jnp

# Problem sizes: signal length, and the exclusive upper bound of the kernel values.
size1, size2 = 1000, 2 + 101

# Signal 0 .. size1-1 and integer kernel 2 .. size2-1 (i.e. size2 - 2 taps).
x = np.arange(size1)
w = np.arange(2, size2)
# Small float kernel kept around as an alternative:
# w = np.array([2., 3., 4.])

def convolve(x, w):
  """Slide the kernel ``w`` along ``x`` and return one dot product per window.

  Despite the name, the kernel is not flipped, so this is a cross-correlation:
  ``out[j] = dot(x[j:j + len(w)], w)``.

  The window width is read from ``w`` itself instead of a module-level size
  constant, so the function works for any 1-D kernel and does not depend on
  notebook state. The windows are gathered in one indexing operation and
  reduced with a single matrix-vector product, which traces to a handful of
  ops under ``jit``/``vmap``/``pmap`` instead of one slice + dot per output.

  Args:
    x: 1-D signal (NumPy array, JAX array, or a tracer).
    w: 1-D kernel.

  Returns:
    1-D JAX array with ``len(x) - len(w) - 1`` entries (empty when ``x`` is
    too short to hold that many windows).
  """
  x = jnp.asarray(x)
  w = jnp.asarray(w)
  taps = w.shape[0]
  # Window starts are 0 .. len(x) - taps - 2, which keeps the output length of
  # the earlier loop-based version.
  # NOTE(review): the two last fully-valid windows (starts len(x) - taps - 1
  # and len(x) - taps) are not produced; confirm whether that is intended.
  n_windows = max(x.shape[0] - taps - 1, 0)
  starts = jnp.arange(n_windows)[:, None]
  offsets = jnp.arange(taps)[None, :]
  # windows[j, k] == x[j + k]; shape (n_windows, taps).
  windows = x[starts + offsets]
  return jnp.dot(windows, w)

# convolve(x, w)

In [7]:
# Build one signal row and one kernel copy per local device.
n_devices = jax.local_device_count()
xs = np.arange(n_devices * size1).reshape(-1, size1)
ws = np.tile(w, (n_devices, 1))

xs

array([[   0,    1,    2, ...,  997,  998,  999],
       [1000, 1001, 1002, ..., 1997, 1998, 1999],
       [2000, 2001, 2002, ..., 2997, 2998, 2999],
       ...,
       [5000, 5001, 5002, ..., 5997, 5998, 5999],
       [6000, 6001, 6002, ..., 6997, 6998, 6999],
       [7000, 7001, 7002, ..., 7997, 7998, 7999]])

In [6]:
# Display the stacked kernels: ws holds one copy of w per device row.
ws

array([[  2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,  14,
         15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,
         41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,
         54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,
         67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,  79,
         80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,  92,
         93,  94,  95,  96,  97,  98,  99, 100, 101, 102],
       [  2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,  14,
         15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,
         41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,
         54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,
         67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  7

In [15]:
# jit-compile the single-row function, then vmap it so each row of xs is
# paired with the matching row of ws.
jax.vmap(jax.jit(convolve))(xs, ws)

Array([[  348450,   353702,   358954, ...,  5048990,  5054242,  5059494],
       [ 5600450,  5605702,  5610954, ..., 10300990, 10306242, 10311494],
       [10852450, 10857702, 10862954, ..., 15552990, 15558242, 15563494],
       ...,
       [26608450, 26613702, 26618954, ..., 31308990, 31314242, 31319494],
       [31860450, 31865702, 31870954, ..., 36560990, 36566242, 36571494],
       [37112450, 37117702, 37122954, ..., 41812990, 41818242, 41823494]],      dtype=int32)

In [16]:
# Same row-wise mapping without the explicit jit wrapper; xs is offset by 3
# elementwise, so the values differ from the previous cell.
jax.vmap(convolve)(xs+3, ws)

Array([[  364206,   369458,   374710, ...,  5064746,  5069998,  5075250],
       [ 5616206,  5621458,  5626710, ..., 10316746, 10321998, 10327250],
       [10868206, 10873458, 10878710, ..., 15568746, 15573998, 15579250],
       ...,
       [26624206, 26629458, 26634710, ..., 31324746, 31329998, 31335250],
       [31876206, 31881458, 31886710, ..., 36576746, 36581998, 36587250],
       [37128206, 37133458, 37138710, ..., 41828746, 41833998, 41839250]],      dtype=int32)

In [17]:
# pmap variant: xs and ws were built with n_devices rows, so the leading axis
# matches the local device count. Inputs are offset (xs by 10, ws by 3).
jax.pmap(convolve)(xs+10, ws+3)

Array([[  419150,   424705,   430260, ...,  5390875,  5396430,  5401985],
       [ 5974150,  5979705,  5985260, ..., 10945875, 10951430, 10956985],
       [11529150, 11534705, 11540260, ..., 16500875, 16506430, 16511985],
       ...,
       [28194150, 28199705, 28205260, ..., 33165875, 33171430, 33176985],
       [33749150, 33754705, 33760260, ..., 38720875, 38726430, 38731985],
       [39304150, 39309705, 39315260, ..., 44275875, 44281430, 44286985]],      dtype=int32)