## Using Cloud TPU and a Local runtime in Colab

Prepare the environment as described in Appendix C, or follow these links:

* Creating a Cloud TPU: https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#tpu-vms

* Preparing Jupyter and connecting to a local runtime: https://research.google.com/colaboratory/local-runtimes.html


In [1]:
!pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Looking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html


The `rich` package is needed for the pretty sharding visualization (`jax.debug.visualize_array_sharding`):

In [2]:
!pip install rich

Collecting rich
  Downloading rich-13.7.0-py3-none-any.whl (240 kB)
Collecting markdown-it-py>=2.2.0
  Downloading markdown_it_py-3.0.0-py3-none-any.whl (87 kB)
Collecting mdurl~=0.1
  Downloading mdurl-0.1.2-py3-none-any.whl (10.0 kB)
Installing collected packages: mdurl, markdown-it-py, rich
Successfully installed markdown-it-py-3.0.0 mdurl-0.1.2 rich-13.7.0


In [3]:
import jax
import jax.numpy as jnp

In [4]:
jax.local_devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [5]:
jax.__version__

'0.4.13'

## Using Tensor Sharding

In [6]:
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding, NamedSharding

In [7]:
from jax import random

In [8]:
import numpy as np

### Dot example

In [9]:
def dot(v1, v2):
  return jnp.vdot(v1, v2)

In [10]:
rng_key = random.PRNGKey(42)

In [11]:
vs = random.normal(rng_key, shape=(8_000,10_000))
v1s = vs[:4_000,:]
v2s = vs[4_000:,:]

v1s.shape, v2s.shape

((4000, 10000), (4000, 10000))

In [12]:
jax.debug.visualize_array_sharding(v1s)

### Positional sharding

In [13]:
sharding = PositionalSharding(mesh_utils.create_device_mesh((8,1)))

In [14]:
sharding

PositionalSharding([[{TPU 0}]
                    [{TPU 1}]
                    [{TPU 2}]
                    [{TPU 3}]
                    [{TPU 6}]
                    [{TPU 7}]
                    [{TPU 4}]
                    [{TPU 5}]])

In [15]:
v1sp = jax.device_put(v1s, sharding)

In [16]:
type(v1sp)

jaxlib.xla_extension.ArrayImpl

In [17]:
jax.debug.visualize_array_sharding(v1sp)

In [18]:
v2sp = jax.device_put(v2s, sharding)

In [19]:
jax.debug.visualize_array_sharding(v2sp)

Both inputs are now sharded across all eight devices.

In [20]:
d = jax.vmap(dot)(v1sp, v2sp)

In [21]:
d.shape

(4000,)

In [22]:
jax.debug.visualize_array_sharding(d)

In [23]:
%timeit jax.vmap(dot)(v1sp, v2sp).block_until_ready()

1.54 ms ± 26.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [24]:
%timeit jax.vmap(dot)(v1s, v2s).block_until_ready()

2.13 ms ± 51.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [25]:
jax.make_jaxpr(jax.vmap(dot))(v1s, v2s)

{ lambda ; a:f32[4000,10000] b:f32[4000,10000]. let
    c:f32[4000] = dot_general[dimension_numbers=(([1], [1]), ([0], [0]))] a b
  in (c,) }

In [26]:
modules = jax.jit(jax.vmap(dot)).lower(v1s, v2s).compile().compiler_ir()
for hlo in modules:
  print(hlo.to_string())

HloModule jit_dot, is_scheduled=true, entry_computation_layout={(f32[4000,10000]{1,0:T(8,128)}, f32[4000,10000]{1,0:T(8,128)})->f32[4000]{0:T(1024)}}, allow_spmd_sharding_propagation_to_output={true}

%scalar_add_computation (scalar_lhs: f32[], scalar_rhs: f32[]) -> f32[] {
  %scalar_rhs = f32[]{:T(256)} parameter(1)
  %scalar_lhs = f32[]{:T(256)} parameter(0)
  ROOT %add = f32[]{:T(256)} add(f32[]{:T(256)} %scalar_lhs, f32[]{:T(256)} %scalar_rhs)
}

%fused_computation (param_0.2: f32[4000,10000], param_1.2: f32[4000,10000]) -> f32[4000] {
  %param_0.2 = f32[4000,10000]{1,0:T(8,128)} parameter(0)
  %param_1.2 = f32[4000,10000]{1,0:T(8,128)} parameter(1)
  %multiply.1 = f32[4000,10000]{1,0:T(8,128)} multiply(f32[4000,10000]{1,0:T(8,128)} %param_0.2, f32[4000,10000]{1,0:T(8,128)} %param_1.2)
  %constant.1 = f32[]{:T(256)} constant(0)
  ROOT %reduce.1 = f32[4000]{0:T(1024)} reduce(f32[4000,10000]{1,0:T(8,128)} %multiply.1, f32[]{:T(256)} %constant.1), dimensions={1}, to_apply=%scalar_add_

In [27]:
jax.make_jaxpr(jax.vmap(dot))(v1sp, v2sp)

{ lambda ; a:f32[4000,10000] b:f32[4000,10000]. let
    c:f32[4000] = dot_general[dimension_numbers=(([1], [1]), ([0], [0]))] a b
  in (c,) }

In [28]:
modules = jax.jit(jax.vmap(dot)).lower(v1sp, v2sp).compile().compiler_ir()
for hlo in modules:
  print(hlo.to_string())

HloModule jit_dot, is_scheduled=true, entry_computation_layout={(f32[500,10000]{1,0:T(8,128)}, f32[500,10000]{1,0:T(8,128)})->f32[500]{0:T(512)}}, allow_spmd_sharding_propagation_to_output={true}

%scalar_add_computation (scalar_lhs: f32[], scalar_rhs: f32[]) -> f32[] {
  %scalar_rhs = f32[]{:T(256)} parameter(1)
  %scalar_lhs = f32[]{:T(256)} parameter(0)
  ROOT %add = f32[]{:T(256)} add(f32[]{:T(256)} %scalar_lhs, f32[]{:T(256)} %scalar_rhs)
}

%fused_computation (param_0.2: f32[500,10000], param_1.2: f32[500,10000]) -> f32[500] {
  %param_0.2 = f32[500,10000]{1,0:T(8,128)} parameter(0)
  %param_1.2 = f32[500,10000]{1,0:T(8,128)} parameter(1)
  %multiply.2 = f32[500,10000]{1,0:T(8,128)} multiply(f32[500,10000]{1,0:T(8,128)} %param_0.2, f32[500,10000]{1,0:T(8,128)} %param_1.2)
  %constant.2 = f32[]{:T(256)} constant(0)
  ROOT %reduce.2 = f32[500]{0:T(512)} reduce(f32[500,10000]{1,0:T(8,128)} %multiply.2, f32[]{:T(256)} %constant.2), dimensions={1}, to_apply=%scalar_add_computation, me
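
The per-device shapes in the two compiled modules above (`f32[4000,10000]` for the unsharded inputs, `f32[500,10000]` for the sharded ones) follow from dividing each array dimension by the matching mesh dimension. A minimal plain-Python sketch of that arithmetic; `shard_shape` is an illustrative helper, not a JAX function, and it assumes an even split:

```python
def shard_shape(global_shape, mesh_shape):
  """Per-device block shape when each array axis is split evenly over the matching mesh axis."""
  if len(global_shape) != len(mesh_shape):
    raise ValueError("array rank and mesh rank must match")
  block = []
  for dim, parts in zip(global_shape, mesh_shape):
    if dim % parts != 0:
      raise ValueError(f"dimension {dim} is not divisible by {parts}")
    block.append(dim // parts)
  return tuple(block)

# (8, 1) mesh: rows are split eight ways, columns are left whole.
print(shard_shape((4000, 10000), (8, 1)))  # (500, 10000)
# (2, 4) mesh: rows are split in two, columns in four.
print(shard_shape((4000, 10000), (2, 4)))  # (2000, 2500)
```

The second call corresponds to the `(2,4)` mesh used in the next section.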

### 2D mesh example


In [29]:
sharding = PositionalSharding(mesh_utils.create_device_mesh((2,4)))

In [30]:
v1sp = jax.device_put(v1s, sharding)
v2sp = jax.device_put(v2s, sharding)

In [31]:
jax.debug.visualize_array_sharding(v1sp)

In [32]:
d = jax.vmap(dot)(v1sp, v2sp)
d.shape

(4000,)

In [33]:
jax.debug.visualize_array_sharding(d)

In [34]:
%timeit jax.vmap(dot)(v1sp, v2sp).block_until_ready()

1.61 ms ± 39.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [35]:
%timeit jax.vmap(dot)(v1s, v2s).block_until_ready()

2.32 ms ± 15.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


Looking at the compiled HLO (note the all-reduce operation):

In [36]:
modules = jax.jit(jax.vmap(dot)).lower(v1sp, v2sp).compile().compiler_ir()
for hlo in modules:
  print(hlo.to_string())

HloModule jit_dot, is_scheduled=true, entry_computation_layout={(f32[2000,2500]{1,0:T(8,128)}, f32[2000,2500]{1,0:T(8,128)})->f32[2000]{0:T(1024)}}, allow_spmd_sharding_propagation_to_output={true}

%scalar_add_computation (scalar_lhs: f32[], scalar_rhs: f32[]) -> f32[] {
  %scalar_rhs = f32[]{:T(256)} parameter(1)
  %scalar_lhs = f32[]{:T(256)} parameter(0)
  ROOT %add = f32[]{:T(256)} add(f32[]{:T(256)} %scalar_lhs, f32[]{:T(256)} %scalar_rhs)
}

%fused_computation (param_0.2: f32[2000,2500], param_1.2: f32[2000,2500]) -> f32[2000] {
  %param_0.2 = f32[2000,2500]{1,0:T(8,128)} parameter(0)
  %param_1.2 = f32[2000,2500]{1,0:T(8,128)} parameter(1)
  %multiply.2 = f32[2000,2500]{1,0:T(8,128)} multiply(f32[2000,2500]{1,0:T(8,128)} %param_0.2, f32[2000,2500]{1,0:T(8,128)} %param_1.2)
  %constant.2 = f32[]{:T(256)} constant(0)
  ROOT %reduce.2 = f32[2000]{0:T(1024)} reduce(f32[2000,2500]{1,0:T(8,128)} %multiply.2, f32[]{:T(256)} %constant.2), dimensions={1}, to_apply=%scalar_add_computatio
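
On the `(2,4)` mesh every row of length 10,000 is split into four 2,500-element column blocks, so each device can only compute a partial dot product over its own block. The partial results then have to be summed across the four devices, which is the job of the all-reduce mentioned above. The following plain-Python sketch shows only that arithmetic on a toy vector; the helper names are illustrative and nothing here is JAX API:

```python
def dot(u, v):
  return sum(a * b for a, b in zip(u, v))

def sharded_dot(u, v, num_shards):
  """Split both vectors into equal chunks, take one partial dot product per chunk, then sum."""
  assert len(u) == len(v) and len(u) % num_shards == 0
  chunk = len(u) // num_shards
  partials = [
      dot(u[i * chunk:(i + 1) * chunk], v[i * chunk:(i + 1) * chunk])
      for i in range(num_shards)
  ]
  # The final sum plays the role of the all-reduce across devices.
  return partials, sum(partials)

u = list(range(1, 9))  # 1..8
v = [2] * 8
partials, total = sharded_dot(u, v, num_shards=4)
print(partials)  # [6, 14, 22, 30]
print(total)     # 72
```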

### Using replication

In [37]:
sharding = PositionalSharding(mesh_utils.create_device_mesh((2,4)))

v1sp = jax.device_put(v1s, sharding.replicate(axis=1))

jax.debug.visualize_array_sharding(v1sp)

In [38]:
A = random.normal(rng_key, shape=(10000,2000))
B = random.normal(rng_key, shape=(2000,5000))

In [39]:
Ad = jax.device_put(A, sharding.replicate(1))
Bd = jax.device_put(B, sharding.replicate(0))

In [40]:
jax.debug.visualize_array_sharding(Ad)
jax.debug.visualize_array_sharding(Bd)

In [41]:
Cd = jnp.dot(Ad, Bd)

In [42]:
jax.debug.visualize_array_sharding(Cd)

In [43]:
C = A @ B

In [44]:
jax.numpy.array_equal(C,Cd)

Array(True, dtype=bool)

In [45]:
jax.debug.visualize_array_sharding(C)

In [46]:
C.shape, Cd.shape

((10000, 5000), (10000, 5000))

In [47]:
C[12,3], Cd[12,3]

(Array(43.027637, dtype=float32), Array(43.027637, dtype=float32))

In [48]:
Ca = jnp.dot(A, B)

In [49]:
jax.numpy.array_equal(C,Ca)

Array(True, dtype=bool)

In [50]:
jax.debug.visualize_array_sharding(Ca)

In [51]:
d = (Cd - C)

In [52]:
jnp.max(d), jnp.sum(d)

(Array(0., dtype=float32), Array(0., dtype=float32))

In [53]:
%timeit jnp.dot(Ad, Bd).block_until_ready()

2.12 ms ± 58.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [54]:
%timeit jnp.dot(A, B).block_until_ready()

10.4 ms ± 10.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [56]:
%timeit (A@B).block_until_ready()

10.4 ms ± 11.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [57]:
modules = jnp.dot.lower(A, B).compile().compiler_ir()
for hlo in modules:
  print(hlo.to_string())

HloModule jit_dot, is_scheduled=true, entry_computation_layout={(f32[10000,2000]{0,1:T(8,128)}, f32[2000,5000]{1,0:T(8,128)})->f32[10000,5000]{0,1:T(8,128)}}, allow_spmd_sharding_propagation_to_output={true}

%bitcast_fusion (bf16input: f32[10000,2000]) -> f32[10000,2000] {
  %bf16input = f32[10000,2000]{0,1:T(8,128)} parameter(0)
  ROOT %bitcast = f32[10000,2000]{0,1:T(8,128)} bitcast(f32[10000,2000]{0,1:T(8,128)} %bf16input)
}

%bitcast_fusion.1 (bf16input.1: f32[2000,5000]) -> f32[2000,5000] {
  %bf16input.1 = f32[2000,5000]{1,0:T(8,128)} parameter(0)
  ROOT %bitcast.1 = f32[2000,5000]{1,0:T(8,128)} bitcast(f32[2000,5000]{1,0:T(8,128)} %bf16input.1)
}

%fused_computation (param_0: f32[10000,2000], param_1: f32[2000,5000]) -> f32[10000,5000] {
  %param_0 = f32[10000,2000]{0,1:T(8,128)} parameter(0)
  %fusion.1 = f32[10000,2000]{0,1:T(8,128)} fusion(f32[10000,2000]{0,1:T(8,128)} %param_0), kind=kLoop, calls=%bitcast_fusion
  %param_1 = f32[2000,5000]{1,0:T(8,128)} parameter(1)
  %fusi

In [58]:
modules = jnp.dot.lower(Ad, Bd).compile().compiler_ir()
for hlo in modules:
  print(hlo.to_string())

HloModule jit_dot, is_scheduled=true, entry_computation_layout={(f32[5000,2000]{1,0:T(8,128)}, f32[2000,1250]{1,0:T(8,128)})->f32[5000,1250]{1,0:T(8,128)}}, allow_spmd_sharding_propagation_to_output={true}

%bitcast_fusion (bf16input: f32[5000,2000]) -> f32[5000,2000] {
  %bf16input = f32[5000,2000]{1,0:T(8,128)} parameter(0)
  ROOT %bitcast = f32[5000,2000]{1,0:T(8,128)} bitcast(f32[5000,2000]{1,0:T(8,128)} %bf16input)
}

%bitcast_fusion.1 (bf16input.1: f32[2000,1250]) -> f32[2000,1250] {
  %bf16input.1 = f32[2000,1250]{1,0:T(8,128)} parameter(0)
  ROOT %bitcast.1 = f32[2000,1250]{1,0:T(8,128)} bitcast(f32[2000,1250]{1,0:T(8,128)} %bf16input.1)
}

%fused_computation (param_0: f32[5000,2000], param_1: f32[2000,1250]) -> f32[5000,1250] {
  %param_0 = f32[5000,2000]{1,0:T(8,128)} parameter(0)
  %fusion.1 = f32[5000,2000]{1,0:T(8,128)} fusion(f32[5000,2000]{1,0:T(8,128)} %param_0), kind=kLoop, calls=%bitcast_fusion
  %param_1 = f32[2000,1250]{1,0:T(8,128)} parameter(1)
  %fusion.2 = f32[2
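
The entry layout of the sharded module above (`f32[5000,2000]` times `f32[2000,1250]` giving `f32[5000,1250]`) reflects a block decomposition: `A` is split into two row blocks, `B` into four column blocks, and each device multiplies one row block by one column block without needing data from the others. A toy plain-Python sketch of the same decomposition, using small integer matrices and illustrative helper names:

```python
def matmul(a, b):
  """Naive matrix product of two lists of rows."""
  return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def row_blocks(m, parts):
  size = len(m) // parts
  return [m[i * size:(i + 1) * size] for i in range(parts)]

def col_blocks(m, parts):
  size = len(m[0]) // parts
  return [[row[j * size:(j + 1) * size] for row in m] for j in range(parts)]

A = [[1, 2], [3, 4], [5, 6], [7, 8]]  # 4 x 2
B = [[1, 0, 2, 0], [0, 1, 0, 2]]      # 2 x 4
full = matmul(A, B)

# 2 row blocks of A times 4 column blocks of B: a 2 x 4 grid of independent products.
blocks = [[matmul(a_blk, b_blk) for b_blk in col_blocks(B, 4)]
          for a_blk in row_blocks(A, 2)]

# Stitch the output blocks back together row by row.
stitched = []
for block_row in blocks:
  for r in range(len(block_row[0])):
    stitched.append([x for blk in block_row for x in blk[r]])

print(stitched == full)  # True
```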

### Using sharding constraints

In [59]:
from jax import jit
from functools import partial

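# `sharding` is not an array, so it is passed as a static argument.
@partial(jax.jit, static_argnums=2)
def distributed_mul(a, b, sharding):
  # Constrain the operand layouts inside the jitted function:
  # `a` is split by rows, `b` is split by columns.
  ad = jax.lax.with_sharding_constraint(a, sharding.replicate(1))
  bd = jax.lax.with_sharding_constraint(b, sharding.replicate(0))
  return jnp.dot(ad, bd)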

In [60]:
sharding = PositionalSharding(mesh_utils.create_device_mesh((2,4)))

In [61]:
jax.debug.visualize_array_sharding(A)
jax.debug.visualize_array_sharding(B)

In [62]:
d = distributed_mul(A, B, sharding)

In [63]:
jax.debug.visualize_array_sharding(d)

In [64]:
d.shape

(10000, 5000)

In [65]:
@jit
def nondistributed_mul(a, b):
  return jnp.dot(a, b)

In [66]:
dn = nondistributed_mul(A, B)

In [67]:
jax.debug.visualize_array_sharding(dn)

In [68]:
%timeit distributed_mul(A, B, sharding).block_until_ready()

21.2 ms ± 101 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [69]:
%timeit nondistributed_mul(A, B).block_until_ready()

10.4 ms ± 9.82 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [70]:
modules = distributed_mul.lower(A, B, sharding).compile().compiler_ir()
for hlo in modules:
  print(hlo.to_string())

HloModule jit_distributed_mul, is_scheduled=true, entry_computation_layout={(f32[10000,2000]{0,1:T(8,128)}, f32[2000,5000]{1,0:T(8,128)})->f32[5000,1250]{1,0:T(8,128)}}, allow_spmd_sharding_propagation_to_output={true}

%fused_computation.1 (param_0.2: f32[10000,2000], param_1.2: s32[]) -> bf16[5000,2000] {
  %param_0.2 = f32[10000,2000]{0,1:T(8,128)} parameter(0)
  %param_1.2 = s32[]{:T(256)} parameter(1)
  %constant.23 = s32[]{:T(256)} constant(0), metadata={op_name="jit(distributed_mul)/jit(main)/sharding_constraint[sharding=GSPMDSharding({devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}) resource_env=ResourceEnv(Mesh(device_ids=[], axis_names=()), ()) unconstrained_dims={}]" source_file="/tmp/ipykernel_10444/1742017423.py" source_line=6}
  ROOT %dynamic-slice.8 = bf16[5000,2000]{0,1:T(8,128)(2,1)} dynamic-slice(f32[10000,2000]{0,1:T(8,128)} %param_0.2, s32[]{:T(256)} %param_1.2, s32[]{:T(256)} %constant.23), dynamic_slice_sizes={5000,2000}, metadata={op_name="jit(distributed

### Named sharding

In [71]:
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from jax.sharding import NamedSharding

In [72]:
mesh = Mesh(mesh_utils.create_device_mesh((4,2)), axis_names=('batch', 'features'))
sharding = NamedSharding(mesh, P('batch', 'features'))

In [73]:
v1sp = jax.device_put(v1s, sharding)
v2sp = jax.device_put(v2s, sharding)

jax.debug.visualize_array_sharding(v1sp)

In [74]:
d = jax.vmap(dot)(v1sp, v2sp)
d.shape

(4000,)
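
With named sharding, a `PartitionSpec` entry names the mesh axis that splits the corresponding array dimension, and `None` (or an empty spec) leaves it unsplit. The sketch below is a small variation on the cell above, written so that it runs on any number of devices: it builds an `(n, 1)` mesh from whatever `jax.devices()` returns, so the array sizes and the variable names `row_sharded` and `replicated` are assumptions made for illustration only.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n = devices.size

# An (n, 1) mesh with the same axis names as above.
mesh = Mesh(devices.reshape(n, 1), axis_names=('batch', 'features'))

# 8 rows per device, so the row dimension always divides evenly.
x = jnp.arange(8 * n * 4, dtype=jnp.float32).reshape(8 * n, 4)

# Rows split over 'batch', columns left whole.
row_sharded = jax.device_put(x, NamedSharding(mesh, P('batch', None)))
# An empty spec replicates the full array on every device.
replicated = jax.device_put(x, NamedSharding(mesh, P()))

row_shard_shape = row_sharded.addressable_shards[0].data.shape
rep_shard_shape = replicated.addressable_shards[0].data.shape
print(row_shard_shape, rep_shard_shape)
```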

### Device placement policy and errors

Arrays placed on different sets of devices:

In [75]:
sharding_a = PositionalSharding(np.array(jax.devices()[:4]).reshape(4,1))
sharding_b = PositionalSharding(np.array(jax.devices()[4:]).reshape(4,1))

In [76]:
sharding_a

PositionalSharding([[{TPU 0}]
                    [{TPU 1}]
                    [{TPU 2}]
                    [{TPU 3}]])

In [77]:
sharding_b

PositionalSharding([[{TPU 4}]
                    [{TPU 5}]
                    [{TPU 6}]
                    [{TPU 7}]])

In [78]:
v1sp = jax.device_put(v1s, sharding_a)
v2sp = jax.device_put(v2s, sharding_b)

In [79]:
jax.debug.visualize_array_sharding(v1sp)

In [80]:
jax.debug.visualize_array_sharding(v2sp)

In [81]:
d = jax.vmap(dot)(v1sp, v2sp)

ValueError: Received incompatible devices for jitted computation. Got ARG_SHARDING with device ids [0, 1, 2, 3] on platform TPU and ARG_SHARDING with device ids [4, 5, 6, 7] on platform TPU

The same devices in a different order:

In [82]:
sharding_a = PositionalSharding(np.array(jax.devices()).reshape(8,1))
sharding_b = PositionalSharding(np.array(jax.devices()[::-1]).reshape(8,1))

In [83]:
sharding_a, sharding_b

(PositionalSharding([[{TPU 0}]
                     [{TPU 1}]
                     [{TPU 2}]
                     [{TPU 3}]
                     [{TPU 4}]
                     [{TPU 5}]
                     [{TPU 6}]
                     [{TPU 7}]]),
 PositionalSharding([[{TPU 7}]
                     [{TPU 6}]
                     [{TPU 5}]
                     [{TPU 4}]
                     [{TPU 3}]
                     [{TPU 2}]
                     [{TPU 1}]
                     [{TPU 0}]]))

In [84]:
v1sp = jax.device_put(v1s, sharding_a)
v2sp = jax.device_put(v2s, sharding_b)

In [85]:
jax.debug.visualize_array_sharding(v1sp)

In [86]:
jax.debug.visualize_array_sharding(v2sp)

In [87]:
d = jax.vmap(dot)(v1sp, v2sp)

ValueError: Received incompatible devices for jitted computation. Got ARG_SHARDING with device ids [0, 1, 2, 3, 4, 5, 6, 7] on platform TPU and ARG_SHARDING with device ids [7, 6, 5, 4, 3, 2, 1, 0] on platform TPU

In [88]:
d = jax.vmap(dot)(v1sp, v2s)

In [89]:
jax.debug.visualize_array_sharding(d)
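
Both errors above come from combining two arrays that were explicitly placed with incompatible shardings, while the last cell works because `v2s` was never explicitly placed. One possible way to recover, sketched below under stated assumptions, is to re-place one operand with the other operand's sharding before the computation. The sketch uses `NamedSharding` on an `(n, 1)` mesh built from whatever devices are available so that it runs anywhere; the array sizes and names are illustrative and not taken from the notebook.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n = devices.size

# Two shardings over the same devices, listed in opposite order.
sharding_a = NamedSharding(Mesh(devices.reshape(n, 1), ('x', 'y')), P('x', None))
sharding_b = NamedSharding(Mesh(devices[::-1].reshape(n, 1), ('x', 'y')), P('x', None))

a = jax.device_put(jnp.ones((4 * n, 3)), sharding_a)
b = jax.device_put(jnp.ones((4 * n, 3)), sharding_b)

# Re-place `b` so that both operands share one sharding before combining them.
b_aligned = jax.device_put(b, a.sharding)
result = jnp.sum(a * b_aligned, axis=1)
print(result.shape)
```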

## MLP example

### Preparing data

Install these modules if you are starting from a freshly created cloud machine:

In [169]:
!pip install tensorflow

Collecting tensorflow
  Downloading tensorflow-2.13.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (479.6 MB)
Collecting absl-py>=1.0.0
  Downloading absl_py-2.0.0-py3-none-any.whl (130 kB)
Collecting astunparse>=1.6.0
  Downloading astunparse-1.6.3-py2.py3-none-any.whl (12 kB)
Collecting flatbuffers>=23.1.21
  Downloading flatbuffers-23.5.26-py2.py3-none-any.whl (26 kB)
Collecting gast<=0.4.0,>=0.2.1
  Downloading gast-0.4.0-py3-none-any.whl (9.8 kB)
Collecting google-pasta>=0.1.1
  Downloading google_pasta-0.2.0-py3-none-any.whl (57 kB)
Collecting grpcio<2.0,>=1.24.3
  Downloading grpcio-1.60.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.4 MB)
Collecting h5py>=2.9.0
  Downloading h5py-3.10.0-cp3

In [170]:
!pip install tensorflow_datasets

Collecting tensorflow_datasets
  Downloading tensorflow_datasets-4.9.2-py3-none-any.whl (5.4 MB)
Collecting array-record
  Downloading array_record-0.4.0-py38-none-any.whl (3.0 MB)
Collecting dm-tree
  Downloading dm_tree-0.1.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (152 kB)
Collecting etils[enp,epath]>=0.9.0
  Downloading etils-1.3.0-py3-none-any.whl (126 kB)
Collecting promise
  Downloading promise-2.3.tar.gz (19 kB)
Collecting tensorflow-metadata
  Downloading tensorflow_metadata-1.14.0-py3-none-any.whl (28 kB)
Collecting toml
  Downloading toml-0.10.2-py2.py3-none-any.whl (16 kB)
Collecting tqdm
  Downloading tqdm-4.66.1-py3-none-any.whl (78 kB)
Collecting googl

In [171]:
import jax
import tensorflow as tf
import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

data, info = tfds.load(name="mnist",
                       data_dir=data_dir,
                       as_supervised=True,
                       with_info=True)

data_train = data['train']
data_test  = data['test']

Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /tmp/tfds/mnist/3.0.1...

Dataset mnist downloaded and prepared to /tmp/tfds/mnist/3.0.1. Subsequent calls will reuse this data.





In [172]:
HEIGHT = 28
WIDTH  = 28
CHANNELS = 1
NUM_PIXELS = HEIGHT * WIDTH * CHANNELS
NUM_LABELS = info.features['label'].num_classes
BATCH_SIZE  = 32 # per-device batch size; the training set has 60k samples in total
NUM_DEVICES = jax.device_count()

In [173]:
def preprocess(img, label):
  """Scale pixel values to [0, 1]; labels pass through unchanged."""
  return (tf.cast(img, tf.float32)/255.0), label

train_data = tfds.as_numpy(
    data_train.map(preprocess).batch(NUM_DEVICES*BATCH_SIZE).prefetch(1)
)
test_data  = tfds.as_numpy(
    data_test.map(preprocess).batch(NUM_DEVICES*BATCH_SIZE).prefetch(1)
)

In [174]:
len(train_data)

235
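
The value 235 follows from the global batch size `NUM_DEVICES*BATCH_SIZE`. A short check of the arithmetic, assuming the eight TPU devices listed by `jax.local_devices()` at the top of the notebook and the 60k training samples mentioned in the comment above:

```python
import math

NUM_TRAIN = 60_000  # MNIST training examples
NUM_DEVICES = 8     # assumption: the eight TPU devices listed earlier
BATCH_SIZE = 32     # per-device batch size

global_batch = NUM_DEVICES * BATCH_SIZE
num_batches = math.ceil(NUM_TRAIN / global_batch)
last_batch = NUM_TRAIN - (num_batches - 1) * global_batch

print(global_batch)              # 256
print(num_batches)               # 235
print(last_batch)                # 96
print(last_batch % NUM_DEVICES)  # 0
```

The final batch is shorter than the others (96 examples) but is still divisible by eight, so it can be split evenly across the devices.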

### Preparing MLP

In [175]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, value_and_grad
from jax import random
from jax.nn import swish, logsumexp, one_hot

In [176]:
LAYER_SIZES = [28*28, 512, 10]
PARAM_SCALE = 0.01

In [177]:
def init_network_params(sizes, key=random.PRNGKey(0), scale=1e-2):
  """Initialize all layers for a fully-connected neural network with given sizes"""

  def random_layer_params(m, n, key, scale=1e-2):
    """A helper function to randomly initialize weights and biases of a dense layer"""
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k, scale) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

init_params = init_network_params(LAYER_SIZES, random.PRNGKey(0), scale=PARAM_SCALE)

In [178]:
def predict(params, image):
  """Function for per-example predictions."""
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = swish(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits

batched_predict = vmap(predict, in_axes=(None, 0))

### Loss and update functions

In [179]:
INIT_LR = 1.0
DECAY_RATE = 0.95
DECAY_STEPS = 5
NUM_EPOCHS  = 20
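
These constants parameterize the exponential learning-rate decay that the `update` function below computes as `INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)`: the rate starts at 1.0 and is multiplied by 0.95 for every five epochs. A plain-Python sketch of the schedule on its own; the function name `learning_rate` is illustrative:

```python
INIT_LR = 1.0
DECAY_RATE = 0.95
DECAY_STEPS = 5

def learning_rate(epoch):
  """Exponentially decayed learning rate for a given epoch number."""
  return INIT_LR * DECAY_RATE ** (epoch / DECAY_STEPS)

for epoch in (0, 5, 10, 19):
  print(epoch, round(learning_rate(epoch), 4))
```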

In [180]:
def loss(params, images, targets):
  """Categorical cross entropy loss function."""
  logits = batched_predict(params, images)
  # logsumexp trick: https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
  # No axis is given, so the normalizer is computed over the whole batch.
  log_preds = logits - logsumexp(logits)
  return -jnp.mean(targets*log_preds)

@jit
def update(params, x, y, epoch_number):
  loss_value, grads = value_and_grad(loss)(params, x, y)
  lr = INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)
  return [(w - lr * dw, b - lr * db)
          for (w, b), (dw, db) in zip(params, grads)], loss_value
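
In the loss above, `logsumexp(logits)` is called without an `axis` argument, so it reduces over every element of the `(batch, classes)` array and subtracts a single constant from the whole batch. A per-example log-softmax would instead normalize each row separately, which in JAX can be written as `logsumexp(logits, axis=1, keepdims=True)`; making that change would alter the loss values printed in this notebook. The plain-Python sketch below only illustrates the difference between the two normalizations on toy values; its `logsumexp` is a local stand-in for the JAX function:

```python
import math

def logsumexp(values):
  """Numerically stable log(sum(exp(v))) over a flat list."""
  m = max(values)
  return m + math.log(sum(math.exp(v - m) for v in values))

logits = [[1.0, 2.0, 3.0],
          [0.5, 0.5, 0.5]]

# Global normalization: one constant for the whole batch.
flat = [v for row in logits for v in row]
global_log_preds = [[v - logsumexp(flat) for v in row] for row in logits]

# Per-example normalization: one constant per row (a log-softmax).
row_log_preds = [[v - logsumexp(row) for v in row] for row in logits]

global_row_sums = [sum(math.exp(v) for v in row) for row in global_log_preds]
per_row_sums = [sum(math.exp(v) for v in row) for row in row_log_preds]

print(sum(global_row_sums))  # probabilities sum to 1 over the whole batch
print(per_row_sums)          # probabilities sum to 1 within each row
```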

### Training loop

In [181]:
@jit
def batch_accuracy(params, images, targets):
  images = jnp.reshape(images, (len(images), NUM_PIXELS))
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == targets)

def accuracy(params, data):
  accs = []
  for images, targets in data:
    accs.append(batch_accuracy(params, images, targets))
  return jnp.mean(jnp.array(accs))

#### 8-way data parallelism

In [182]:
sharding = PositionalSharding(jax.devices()).reshape(8, 1)

In [183]:
import time

params = init_params
for epoch in range(NUM_EPOCHS):
  start_time = time.time()
  losses = []
  for x, y in train_data:
    x = jnp.reshape(x, (len(x), NUM_PIXELS))
    y = one_hot(y, NUM_LABELS)
    x = jax.device_put(x, sharding)
    y = jax.device_put(y, sharding)
    params = jax.device_put(params, sharding.replicate())
    params, loss_value = update(params, x, y, epoch)
    losses.append(jnp.sum(loss_value))
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_data)
  test_acc = accuracy(params, test_data)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set loss {}".format(jnp.mean(jnp.array(losses))))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 1.91 sec
Training set loss 0.6914460062980652
Training set accuracy 0.8648548126220703
Test set accuracy 0.874804675579071
Epoch 1 in 1.00 sec
Training set loss 0.6259850859642029
Training set accuracy 0.8882258534431458
Test set accuracy 0.893847644329071
Epoch 2 in 0.99 sec
Training set loss 0.6164931058883667
Training set accuracy 0.8978224396705627
Test set accuracy 0.901074230670929
Epoch 3 in 1.00 sec
Training set loss 0.6100647449493408
Training set accuracy 0.906527042388916
Test set accuracy 0.908984363079071
Epoch 4 in 0.97 sec
Training set loss 0.6044379472732544
Training set accuracy 0.9142231941223145
Test set accuracy 0.917285144329071
Epoch 5 in 1.02 sec
Training set loss 0.5995732545852661
Training set accuracy 0.9211103320121765
Test set accuracy 0.9248046875
Epoch 6 in 1.02 sec
Training set loss 0.5955986976623535
Training set accuracy 0.9261746406555176
Test set accuracy 0.929394543170929
Epoch 7 in 1.01 sec
Training set loss 0.5924407839775085
Training se

In [184]:
jax.debug.visualize_array_sharding(params[0][0])

#### 4-way data parallelism, 2-way tensor parallelism

In [185]:
sharding = PositionalSharding(jax.devices()).reshape(4, 2)

In [186]:
LAYER_SIZES = [28*28, 10000, 10000, 10]
PARAM_SCALE = 0.01

def init_network_params(sizes, key=random.PRNGKey(0), scale=1e-2):
  """Initialize all layers for a fully-connected neural network with given sizes"""

  def random_layer_params(m, n, key, scale=1e-2):
    """A helper function to randomly initialize weights and biases of a dense layer"""
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k, scale) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

init_params = init_network_params(LAYER_SIZES, random.PRNGKey(0), scale=PARAM_SCALE)

In [187]:
sharded_params = []
for i, (w, b) in enumerate(init_params):
  print(i, w.shape, b.shape)
  if i in (0, 1):
    # First two layers: replicated along mesh axis 0, sharded along mesh axis 1.
    w = jax.device_put(w, sharding.replicate(0))
    b = jax.device_put(b, sharding.replicate(0))
  elif i == 2:
    # Output layer: fully replicated on every device.
    w = jax.device_put(w, sharding.replicate())
    b = jax.device_put(b, sharding.replicate())
  sharded_params.append((w, b))


0 (10000, 784) (10000,)
1 (10000, 10000) (10000,)
2 (10, 10000) (10,)
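
The larger layers are what make tensor parallelism worthwhile here. The rough estimate below compares per-device weight storage when everything is replicated with the layout from the cell above. It assumes float32 weights, ignores the biases, and assumes that `sharding.replicate(0)` on the `(4, 2)` layout splits the second dimension of each of the first two weight matrices in two; treat the numbers as a back-of-the-envelope sketch rather than a measurement.

```python
LAYER_SIZES = [28 * 28, 10000, 10000, 10]
BYTES_PER_FLOAT32 = 4
TENSOR_PARALLEL = 2  # second axis of the (4, 2) layout

# Weight shapes are (outputs, inputs), matching the shapes printed above.
weight_shapes = [(n, m) for m, n in zip(LAYER_SIZES[:-1], LAYER_SIZES[1:])]

def num_elements(shape):
  rows, cols = shape
  return rows * cols

replicated = sum(num_elements(s) for s in weight_shapes)

# Assumption: the first two matrices are split in two along their second
# dimension; the last layer stays fully replicated.
split = (
    sum(num_elements((rows, cols // TENSOR_PARALLEL)) for rows, cols in weight_shapes[:2])
    + num_elements(weight_shapes[2])
)

print(weight_shapes)
print(replicated * BYTES_PER_FLOAT32 / 1e6, "MB per device when replicated")
print(split * BYTES_PER_FLOAT32 / 1e6, "MB per device with 2-way tensor parallelism")
```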


In [188]:
for (w,b) in init_params:
  jax.debug.visualize_array_sharding(w)
  jax.debug.visualize_array_sharding(b)

In [189]:
for (w,b) in init_params:
  jax.debug.visualize_array_sharding(jax.device_put(w, sharding.replicate()))
  jax.debug.visualize_array_sharding(jax.device_put(b, sharding.replicate()))

In [190]:
for (w,b) in sharded_params:
  jax.debug.visualize_array_sharding(w)
  jax.debug.visualize_array_sharding(b)

In [191]:
import time

params = sharded_params
for epoch in range(NUM_EPOCHS):
  start_time = time.time()
  losses = []
  for x, y in train_data:
    x = jnp.reshape(x, (len(x), NUM_PIXELS))
    y = one_hot(y, NUM_LABELS)
    x = jax.device_put(x, sharding.replicate(1))
    y = jax.device_put(y, sharding.replicate(1))
    #params = jax.device_put(params, sharding.replicate())
    params, loss_value = update(params, x, y, epoch)
    losses.append(jnp.sum(loss_value))
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_data)
  test_acc = accuracy(params, test_data)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set loss {}".format(jnp.mean(jnp.array(losses))))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

  jax.debug.visualize_array_sharding(params[0][0])

Epoch 0 in 6.61 sec
Training set loss 0.6637932062149048
Training set accuracy 0.8757756948471069
Test set accuracy 0.8863281607627869


Epoch 1 in 3.96 sec
Training set loss 0.6224857568740845
Training set accuracy 0.8902149200439453
Test set accuracy 0.8956055045127869


Epoch 2 in 3.95 sec
Training set loss 0.6166571378707886
Training set accuracy 0.8960660099983215
Test set accuracy 0.900097668170929


Epoch 3 in 3.96 sec
Training set loss 0.613541841506958
Training set accuracy 0.8999667167663574
Test set accuracy 0.9034180045127869


Epoch 4 in 3.97 sec
Training set loss 0.6112023591995239
Training set accuracy 0.9033576846122742
Test set accuracy 0.907421886920929


Epoch 5 in 3.95 sec
Training set loss 0.6091800332069397
Training set accuracy 0.9063829779624939
Test set accuracy 0.909375011920929


Epoch 6 in 3.95 sec
Training set loss 0.6072940230369568
Training set accuracy 0.9090425372123718
Test set accuracy 0.911816418170929


Epoch 7 in 3.95 sec
Training set loss 0.6054378747940063
Training set accuracy 0.911436140537262
Test set accuracy 0.9150390625


Epoch 8 in 3.95 sec
Training set loss 0.6035282611846924
Training set accuracy 0.9151983261108398
Test set accuracy 0.918749988079071


Epoch 9 in 3.94 sec
Training set loss 0.6015031337738037
Training set accuracy 0.9181182980537415
Test set accuracy 0.921191394329071


Epoch 10 in 3.94 sec
Training set loss 0.5993688702583313
Training set accuracy 0.9215425252914429
Test set accuracy 0.9237304925918579


Epoch 11 in 3.93 sec
Training set loss 0.5972492694854736
Training set accuracy 0.9245678186416626
Test set accuracy 0.9267578125


Epoch 12 in 3.94 sec
Training set loss 0.5953165888786316
Training set accuracy 0.9276872873306274
Test set accuracy 0.9283203482627869


Epoch 13 in 3.96 sec
Training set loss 0.5936313271522522
Training set accuracy 0.9301806092262268
Test set accuracy 0.930371105670929


Epoch 14 in 3.95 sec
Training set loss 0.5921530723571777
Training set accuracy 0.932541012763977
Test set accuracy 0.932812511920929


Epoch 15 in 3.96 sec
Training set loss 0.590818464756012
Training set accuracy 0.9347129464149475
Test set accuracy 0.934374988079071


Epoch 16 in 3.96 sec
Training set loss 0.5895960330963135
Training set accuracy 0.9365746378898621
Test set accuracy 0.936328113079071


Epoch 17 in 3.93 sec
Training set loss 0.5884610414505005
Training set accuracy 0.9382202625274658
Test set accuracy 0.9383789300918579


Epoch 18 in 3.93 sec
Training set loss 0.5874113440513611
Training set accuracy 0.9402315616607666
Test set accuracy 0.9404296875


Epoch 19 in 3.93 sec
Training set loss 0.5864334106445312
Training set accuracy 0.9416278004646301
Test set accuracy 0.941699206829071


**Do not forget to shut down your Cloud TPU when you are finished; otherwise it will keep accruing charges!**