In [1]:
import os
# Environment flags are set here, before torch_xla is imported further down.
os.environ['XLA_HLO_DEBUG'] = '1'

# Point XRT at a local service only when no (non-empty) config is already set.
if os.environ.get('XRT_TPU_CONFIG', '') == '':
  os.environ['XRT_TPU_CONFIG'] = 'localservice;0;localhost:51011'

# Display the config that is actually in effect.
os.environ['XRT_TPU_CONFIG']

'localservice;0;localhost:51011'

## XlaBuilder Playground
* [x] Simple Add operation with fwd/bwd to get familiar with `torch_xla.core.xla_builder` primitives.

In [2]:
import torch
import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.core.xla_op_registry as xor
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

In [3]:
def _add_forward(x, y):
  """Lower `x + y` to a single XLA 'Add' op created through the raw builder API.

  Args:
    x: left operand, a `torch_xla.core.xla_builder.Op`; its `.op` attribute is
      the underlying `_XLAC.XlaOp` (`op_builder::Op`).
    y: right operand, the same kind of object as `x`.

  Returns:
    An `xb.Op` wrapping the newly created 'Add' op.
  """
  # The builder is recovered from the first operand.
  xla_builder = torch_xla._XLAC._xla_op_builder(x.op)
  operands = [x.op, y.op]
  add_op = torch_xla._XLAC._xla_op_create(xla_builder, 'Add', operands, {})
  return xb.Op(add_op)


# Register the lowering with the XLA op registry under the name 'AddForward'.
ADD_FORWARD = xor.register('AddForward', _add_forward)

In [4]:
class Add(torch.autograd.Function):
  """Custom autograd function computing `x + y` through the `ADD_FORWARD` XLA op.

  NOTE(review): assumes `x` and `y` have the same shape (no broadcasting), so
  the upstream gradient can be handed to both inputs unchanged.
  """

  @staticmethod
  def forward(ctx, x, y):
    """Return `x + y` as computed by the registered `ADD_FORWARD` op.

    Args:
      ctx: autograd context; nothing needs to be saved for the backward pass.
      x: first addend.
      y: second addend.
    """
    return ADD_FORWARD(x, y)

  @staticmethod
  def backward(ctx, grad_output):
    """Propagate the upstream gradient to both inputs.

    Since d(x + y)/dx = d(x + y)/dy = 1, the chain rule gives each input exactly
    `grad_output`. Returning a constant tensor of ones instead would only be
    correct when the upstream gradient happens to be all ones.

    Args:
      ctx: autograd context (unused).
      grad_output: gradient of the loss with respect to the forward output.

    Returns:
      A `(grad_x, grad_y)` tuple, both equal to `grad_output`.
    """
    return grad_output, grad_output


In [5]:
# Run the custom `Add` autograd function on the XLA device, forward and backward.
# `device` is reused by later cells to place tensors on the same XLA device.
device = xm.xla_device()
x = torch.ones(1, 1, device=device, requires_grad=True)
y = torch.randn(1, 1, device=device, requires_grad=True)

output = Add.apply(x, y)
# Forward-only alternative that bypasses the custom autograd function:
# output = ADD_FORWARD(x, y)
# Reduce to a scalar so that `backward()` can be called without arguments.
loss = output.sum()
loss.backward()

# Show inputs, their gradients and the result for a visual check.
print(f'x: {x}')
print(f'x.grad: {x.grad}')
print(f'y: {y}')
print(f'y.grad: {y.grad}')
print(f'output: {output}')

x: tensor([[1.]], device='xla:1', requires_grad=True)
x.grad: tensor([[1.]], device='xla:1')
y: tensor([[0.1028]], device='xla:1', requires_grad=True)
y.grad: tensor([[1.]], device='xla:1')
output: tensor([[1.1028]], device='xla:1', grad_fn=<AddBackward>)


In [6]:
# Reference run with PyTorch's built-in add and autograd, to compare against
# the custom XLA `Add` results shown above.
t1 = torch.rand(1, 1, requires_grad=True)
t2 = torch.rand(1, 1, requires_grad=True)

s = torch.add(t1, t2)
torch.sum(s).backward()

# Same display order as before: input, its gradient, ..., then the sum.
for tensor in (t1, t1.grad, t2, t2.grad, s):
  print(tensor)

tensor([[0.5333]], requires_grad=True)
tensor([[1.]])
tensor([[0.6611]], requires_grad=True)
tensor([[1.]])
tensor([[1.1944]], grad_fn=<AddBackward0>)


## PyTorch / XLA for tracing, lowering -> JAX PjRt Runtime

In [7]:
import jax
from jax.lib import xla_client as xc
import numpy as np

In [8]:
# Build two 1x1 tensors on the host, move them to the XLA device and multiply
# them there.
two = (2 * torch.ones((1,1))).to(device)
three = (3 * torch.ones((1,1))).to(device)
six = two * three

# Print the HLO text that torch_xla generates for `six`.
print(torch_xla._XLAC._get_xla_tensors_hlo([six]))

HloModule IrToHlo.11

ENTRY %IrToHlo.11 (p0.1: f32[1,1], p1.5: f32[1,1]) -> (f32[1,1]) {
  %constant.2 = f32[] constant(0), metadata={op_type="prim__Constant" op_name="prim__Constant"}
  %reshape.3 = f32[1,1]{1,0} reshape(f32[] %constant.2), metadata={op_type="aten__expand" op_name="aten__expand"}
  %broadcast.4 = f32[1,1]{1,0} broadcast(f32[1,1]{1,0} %reshape.3), dimensions={0,1}, metadata={op_type="aten__expand" op_name="aten__expand"}
  %constant.6 = f32[] constant(0), metadata={op_type="prim__Constant" op_name="prim__Constant"}
  %reshape.7 = f32[1,1]{1,0} reshape(f32[] %constant.6), metadata={op_type="aten__expand" op_name="aten__expand"}
  %broadcast.8 = f32[1,1]{1,0} broadcast(f32[1,1]{1,0} %reshape.7), dimensions={0,1}, metadata={op_type="aten__expand" op_name="aten__expand"}
  %p1.5 = f32[1,1]{1,0} parameter(1), metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %p0.1 = f32[1,1]{1,0} parameter(0), metadata={op_type="xla__device_data" op_name="xla__device_data"}

In [9]:
# Local JAX backend; later cells reuse `backend` to compile computations and to
# create device buffers.
backend = xc.get_local_backend()

# HLO text for the multiply traced above.
hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([six])
# Wrap the serialized computation produced from the HLO text in a JAX
# XlaComputation.
computation = xc.XlaComputation(
    torch_xla._XLAC._hlo_text_to_serialized_xla_computation(hlo_text))
# print(hlo_text)
# Print the HLO as seen from the JAX side.
print(computation.as_hlo_text())

compiled_computation = backend.compile(computation)
# Last expression of the cell: displays the devices of the compiled executable.
compiled_computation.local_devices()

# Earlier, disabled attempt at executing with host inputs (kept for reference):
# host_input = [t1.detach.numpy()]
# # place host variable on device and execute
# # device_input = backend.buffer_from_pyval(host_input)
# compiled_computation.execute([host_input, ])

HloModule IrToHlo.11

ENTRY IrToHlo.11 {
  constant.2 = f32[] constant(0)
  reshape.3 = f32[1,1]{1,0} reshape(constant.2)
  broadcast.4 = f32[1,1]{1,0} broadcast(reshape.3), dimensions={0,1}
  constant.6 = f32[] constant(0)
  reshape.7 = f32[1,1]{1,0} reshape(constant.6)
  broadcast.8 = f32[1,1]{1,0} broadcast(reshape.7), dimensions={0,1}
  p1.5 = f32[1,1]{1,0} parameter(1)
  p0.1 = f32[1,1]{1,0} parameter(0)
  multiply.9 = f32[1,1]{1,0} multiply(p1.5, p0.1)
  ROOT tuple.10 = (f32[1,1]{1,0}) tuple(multiply.9)
}




[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]

In [10]:
# Fresh host inputs matching the f32[1,1] parameters of the compiled multiply;
# each is copied into a device buffer through the JAX backend.
host_x = np.array([[3.0]], dtype=np.float32)
device_x = backend.buffer_from_pyval(host_x)
host_y = np.array([[4.0]], dtype=np.float32)
device_y = backend.buffer_from_pyval(host_y)

# Run the executable on the new inputs (3 * 4); the result is displayed as the
# cell output.
compiled_computation.execute([device_x, device_y])

[DeviceArray([[12.]], dtype=float32)]

## SPMD Playground
Changes needed:
* [x] XLA:HLO - Need to call [XlaBuilder::SetSharding](https://github.com/tensorflow/tensorflow/blob/7a18c91de6272c93468f9987d02a480a61a4b38c/tensorflow/compiler/xla/client/xla_builder.h#L195) with sharding annotation.
* [x] Runtime - Once we have the HLO dump generated by PyTorch / XLA, we leverage the JAX runtime interface to compile and execute the SPMD HLO computation. We plan to migrate to the PjRt runtime anyway, and it already provides many pieces needed for SPMD that are currently missing in XRT.
  * [x] Load XLA HLO text as a `xla::XlaComputation`

We use JAX to do the following:
  * [x] Build options: (1) Set proper device assignment, (2) UseSpmdPartitioning set on compilation options.
  * [x] Execution run on all cores.

Testing:
* [x] HLO Graph dump validation on sharding annotations.
* [x] Dump from post SPMD partitioning pass (collectives insertion).
* [x] Execution result validation.
* [x] Multi-core concurrent execution full-traces sanity check.


## Sharded PyTorch / XLA HLO -> PjRt Runtime

In [11]:
from jax.experimental import sharded_jit
from jax.experimental import PartitionSpec as P
from jax._src.util import prod
from jax.lib import xla_bridge as jxb

In [12]:
# Start the JAX profiler server on port 9012 and keep the returned handle.
# NOTE(review): this rebinds `s`, which the native-autograd cell used for the
# sum tensor.
s = jax.profiler.start_server(9012)

In [13]:
# Sharded tensor matmul
# Partition spec for a 2 x 2 tile grid, converted to the OpSharding proto whose
# fields are forwarded to the XLA builder inside `_sharded_mm`.
shape = (2, 2)
p = P(*shape)
py_opsharding = jxb._sharding_to_proto(p)


def _sharded_mm(x, y):
  """Matrix-multiply two builder ops after routing each through a 'Sharding' call.

  The sharding is read from the module-level `py_opsharding` at the time this
  function runs, not when it is defined.

  Args:
    x: left operand (`torch_xla.core.xla_builder.Op`).
    y: right operand (`torch_xla.core.xla_builder.Op`).

  Returns:
    The op for `x_sharded @ y_sharded`.
  """
  xla_builder = torch_xla._XLAC._xla_op_builder(x.op)

  # Fields of the OpSharding proto handed to the PT / XLA builder.
  # `py_opsharding.tuple_shardings` is intentionally not passed.
  sharding_fields = (
      py_opsharding.replicate_on_last_tile_dim,
      py_opsharding.tile_assignment_devices,
      py_opsharding.tile_assignment_dimensions,
      int(py_opsharding.type),
  )
  torch_xla._XLAC._xla_builder_set_sharding(xla_builder, *sharding_fields)

  # Emit one 'Sharding' XLA custom call per operand, x first and then y.
  lhs, rhs = (operand.custom_call('Sharding') for operand in (x, y))
  return lhs @ rhs


SHARDED_MM = xor.register('ShardedMatMul', _sharded_mm)


# Dump HLO for execution on JAX PjRt.
# The global operands are sized so that every tile of the `shape` grid is
# 4096 x 4096.
two = (2 * torch.ones((shape[0]*4096, shape[1]*4096))).to(device)
three = (3 * torch.ones((shape[0]*4096, shape[1]*4096))).to(device)

res = SHARDED_MM(two, three)
sharded_hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([res])
print(sharded_hlo_text)


# Execute sharded HLO on PjRt.
computation = xc.XlaComputation(
    torch_xla._XLAC._hlo_text_to_serialized_xla_computation(sharded_hlo_text))

nrep = 1
# One partition per tile of the sharding grid, i.e. the *product* of the tile
# counts. The sum of the counts only coincides with it for a 2 x 2 grid.
nparts = int(np.prod(shape))
devices = jxb.local_devices()[:nparts]
# Device assignment of shape (replicas, partitions) holding the device ids.
device_assignment = np.array([[d.id for d in devices]])
device_assignment = np.reshape(device_assignment, (-1, nparts))

compiled_computation = backend.compile(
    computation,
    jxb.get_compile_options(nrep, nparts, device_assignment))


# Allocate device data.
# One 4096 x 4096 host buffer is copied to every participating device for each
# operand. The loop variable is named `dev` so it does not shadow the global
# XLA `device`.
host_x = 3 * np.ones((4096,4096), dtype=np.float32)
host_y = 4 * np.ones((4096,4096), dtype=np.float32)
devices_x = [backend.buffer_from_pyval(host_x, device=dev) for dev in devices]
devices_y = [backend.buffer_from_pyval(host_y, device=dev) for dev in devices]

# Execute repeatedly so the profiler can capture per-step traces; progress is
# printed every 100 steps.
for step in range(3000):
  if step % 100 == 0:
    print(step)
  with jax.profiler.StepTraceAnnotation("step", step_num=step):
    output = compiled_computation.execute_sharded_on_local_devices([devices_x, devices_y])

HloModule IrToHlo.21

%ShardedMatMul.13 (p0.14: f32[8192,8192], p1.15: f32[8192,8192]) -> f32[8192,8192] {
  %p0.14 = f32[8192,8192]{1,0} parameter(0)
  %custom-call.16 = f32[8192,8192]{1,0} custom-call(f32[8192,8192]{1,0} %p0.14), custom_call_target="Sharding", sharding={devices=[2,2]0,1,2,3}
  %p1.15 = f32[8192,8192]{1,0} parameter(1)
  %custom-call.17 = f32[8192,8192]{1,0} custom-call(f32[8192,8192]{1,0} %p1.15), custom_call_target="Sharding", sharding={devices=[2,2]0,1,2,3}
  ROOT %dot.18 = f32[8192,8192]{1,0} dot(f32[8192,8192]{1,0} %custom-call.16, f32[8192,8192]{1,0} %custom-call.17), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[2,2]0,1,2,3}
}

ENTRY %IrToHlo.21 (p0.1: f32[8192,8192], p1.7: f32[8192,8192]) -> (f32[8192,8192]) {
  %constant.2 = f32[] constant(0), metadata={op_type="prim__Constant" op_name="prim__Constant"}
  %reshape.3 = f32[1,1]{1,0} reshape(f32[] %constant.2), metadata={op_type="aten__expand" op_name="aten__expand"}
  %broadcast.4 = f3

In [14]:
# Partition spec for a 2 x 4 tile grid; `_sharded_mul` reads the resulting
# OpSharding proto from this module-level name when it runs.
p = P(2, 4)
py_opsharding = jxb._sharding_to_proto(p)


def _sharded_mul(x, y):
  """Multiply two builder ops after routing each through a 'Sharding' call.

  Args:
    x: left operand (`torch_xla.core.xla_builder.Op`).
    y: right operand (`torch_xla.core.xla_builder.Op`).

  Returns:
    The op for `x_sharded * y_sharded`.
  """
  xla_builder = torch_xla._XLAC._xla_op_builder(x.op)

  # Fields of the OpSharding proto handed to the PT / XLA builder.
  # `py_opsharding.tuple_shardings` is intentionally not passed.
  sharding_fields = (
      py_opsharding.replicate_on_last_tile_dim,
      py_opsharding.tile_assignment_devices,
      py_opsharding.tile_assignment_dimensions,
      int(py_opsharding.type),
  )
  torch_xla._XLAC._xla_builder_set_sharding(xla_builder, *sharding_fields)

  # Emit one 'Sharding' XLA custom call per operand, x first and then y.
  lhs, rhs = (operand.custom_call('Sharding') for operand in (x, y))
  return lhs * rhs


SHARDED_MULTIPLY = xor.register('ShardedMultiply', _sharded_mul)


# Dump HLO for execution on JAX PjRt.
# Global operands are (2*4096) x (4*4096), i.e. one 4096 x 4096 tile per cell
# of the 2 x 4 sharding grid.
two = (2 * torch.ones((2*4096,4*4096))).to(device)
three = (3 * torch.ones((2*4096,4*4096))).to(device)

six = SHARDED_MULTIPLY(two, three)
sharded_hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([six])
print(sharded_hlo_text)


# Execute sharded HLO on PjRt.
computation = xc.XlaComputation(
    torch_xla._XLAC._hlo_text_to_serialized_xla_computation(sharded_hlo_text))

nrep = 1
# 8 partitions = 2 * 4 tiles of the sharding grid.
nparts = 8
devices = jxb.local_devices()[:nparts]
# Device assignment reshaped to (replicas, partitions), holding device ids.
device_assignment = np.array([[d.id for d in devices]])
device_assignment = np.reshape(device_assignment, (-1, nparts))

compiled_computation = backend.compile(
    computation,
    jxb.get_compile_options(nrep, nparts, device_assignment))


# Allocate device data.
# One 4096 x 4096 host buffer per operand is copied to every device in
# `devices` (presumably that device's shard; confirm against the runtime).
host_x = 3 * np.ones((4096,4096), dtype=np.float32)
host_y = 4 * np.ones((4096,4096), dtype=np.float32)
devices_x = [backend.buffer_from_pyval(host_x, device=device) for device in devices]
devices_y = [backend.buffer_from_pyval(host_y, device=device) for device in devices]

# Execute repeatedly under a step annotation for the profiler; progress is
# printed every 100 steps.
for step in range(300):
  if step % 100 == 0:
    print(step)
  with jax.profiler.StepTraceAnnotation("step", step_num=step):
    output = compiled_computation.execute_sharded_on_local_devices([devices_x, devices_y])
# Disabled inspection helpers for the per-device results:
# print(output)
# [(res.shape, res.device()) for res in output[0]]

HloModule IrToHlo.21

%ShardedMultiply.13 (p0.14: f32[8192,16384], p1.15: f32[8192,16384]) -> f32[8192,16384] {
  %p0.14 = f32[8192,16384]{1,0} parameter(0)
  %custom-call.16 = f32[8192,16384]{1,0} custom-call(f32[8192,16384]{1,0} %p0.14), custom_call_target="Sharding", sharding={devices=[2,4]0,1,2,3,4,5,6,7}
  %p1.15 = f32[8192,16384]{1,0} parameter(1)
  %custom-call.17 = f32[8192,16384]{1,0} custom-call(f32[8192,16384]{1,0} %p1.15), custom_call_target="Sharding", sharding={devices=[2,4]0,1,2,3,4,5,6,7}
  ROOT %multiply.18 = f32[8192,16384]{1,0} multiply(f32[8192,16384]{1,0} %custom-call.16, f32[8192,16384]{1,0} %custom-call.17), sharding={devices=[2,4]0,1,2,3,4,5,6,7}
}

ENTRY %IrToHlo.21 (p0.1: f32[8192,16384], p1.7: f32[8192,16384]) -> (f32[8192,16384]) {
  %constant.2 = f32[] constant(0), metadata={op_type="prim__Constant" op_name="prim__Constant"}
  %reshape.3 = f32[1,1]{1,0} reshape(f32[] %constant.2), metadata={op_type="aten__expand" op_name="aten__expand"}
  %broadcast.4 = f32

In [15]:
# # Manually writing HLO graph

# sharded_hlo_text='''
# HloModule IrToHlo.19

# %AddForward.13 (p0.14: f32[2,2], p1.15: f32[2,2]) -> f32[2,2] {
#   %p0.14 = f32[2,2]{1,0} parameter(0)
#   %p1.15 = f32[2,2]{1,0} parameter(1)
#   ROOT %add.16 = f32[2,2]{1,0} add(f32[2,2]{1,0} %p0.14, f32[2,2]{1,0} %p1.15)
# }

# ENTRY %IrToHlo.19 (p0.1: f32[2,2], p1.7: f32[2,2]) -> (f32[2,2]) {
#   %constant.2 = f32[] constant(0)
#   %reshape.3 = f32[1,1]{1,0} reshape(f32[] %constant.2)
#   %broadcast.4 = f32[1,1]{1,0} broadcast(f32[1,1]{1,0} %reshape.3), dimensions={0,1}
#   %reshape.5 = f32[] reshape(f32[1,1]{1,0} %broadcast.4)
#   %broadcast.6 = f32[2,2]{1,0} broadcast(f32[] %reshape.5), dimensions={}
#   %constant.8 = f32[] constant(0)
#   %reshape.9 = f32[1,1]{1,0} reshape(f32[] %constant.8)
#   %broadcast.10 = f32[1,1]{1,0} broadcast(f32[1,1]{1,0} %reshape.9), dimensions={0,1}
#   %reshape.11 = f32[] reshape(f32[1,1]{1,0} %broadcast.10)
#   %broadcast.12 = f32[2,2]{1,0} broadcast(f32[] %reshape.11), dimensions={}
#   %p1.7 = f32[2,2]{1,0} parameter(1)
#   %custom-call.1 = f32[2,2]{1,0} custom-call(%p1.7), custom_call_target="Sharding", sharding={devices=[2,1]0,1}
#   %p0.1 = f32[2,2]{1,0} parameter(0)
#   %custom-call.0 = f32[2,2]{1,0} custom-call(%p0.1), custom_call_target="Sharding", sharding={devices=[2,1]0,1}
#   %call.17 = f32[2,2]{1,0} call(f32[2,2]{1,0} %custom-call.1, f32[2,2]{1,0} %custom-call.0), to_apply=%AddForward.13
#   ROOT %tuple.18 = (f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %call.17)
# }
# '''

In [16]:
# Partition spec for a 2 x 1 tile grid; the function below reads the resulting
# OpSharding proto from this module-level name when it runs.
p = P(2, 1)
py_opsharding = jxb._sharding_to_proto(p)


def _sharded_add_forward_with_custom_call(x, y):
  """Add two builder ops after routing each through a 'Sharding' custom call.

  Args:
    x: left operand, a `torch_xla.core.xla_builder.Op`; its `.op` attribute is
      the underlying `_XLAC.XlaOp` (`op_builder::Op`).
    y: right operand, the same kind of object as `x`.

  Returns:
    The op for `x_sharded + y_sharded`.
  """
  xla_builder = torch_xla._XLAC._xla_op_builder(x.op)

  # Fields of the OpSharding proto handed to the PT / XLA builder.
  # `py_opsharding.tuple_shardings` is intentionally not passed.
  sharding_fields = (
      py_opsharding.replicate_on_last_tile_dim,
      py_opsharding.tile_assignment_devices,
      py_opsharding.tile_assignment_dimensions,
      int(py_opsharding.type),
  )
  torch_xla._XLAC._xla_builder_set_sharding(xla_builder, *sharding_fields)

  # Emit one 'Sharding' XLA custom call per operand, x first and then y.
  lhs, rhs = (operand.custom_call('Sharding') for operand in (x, y))

  # Disabled alternative that creates the 'Add' op through the raw builder API:
  # return xb.Op(torch_xla._XLAC._xla_op_create(
  #     builder, 'Add', [x_sharded.op, y_sharded.op], {}))
  return lhs + rhs


SHARDED_ADD_FORWARD_WITH_CUSTOM_CALL = xor.register(
    'ShardedAddForwardWithCustomCall', _sharded_add_forward_with_custom_call)


# Dump HLO for execution on JAX PjRt.
# Small 2 x 2 operands so that the result can be checked by eye.
two = (2 * torch.ones((2,2))).to(device)
three = (3 * torch.ones((2,2))).to(device)

five = SHARDED_ADD_FORWARD_WITH_CUSTOM_CALL(two, three)
sharded_hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([five])
print(sharded_hlo_text)


# Execute the dumped HLO graph in JAX.
computation = xc.XlaComputation(
    torch_xla._XLAC._hlo_text_to_serialized_xla_computation(sharded_hlo_text))
# print(computation.as_hlo_text())

nrep = 1
# 2 partitions = 2 * 1 tiles of the sharding grid.
nparts = 2
devices = jxb.local_devices()[:nparts]
# Device assignment reshaped to (replicas, partitions), holding device ids.
device_assignment = np.array([[d.id for d in devices]])
device_assignment = np.reshape(device_assignment, (-1, nparts))

compiled_computation = backend.compile(
    computation,
    jxb.get_compile_options(nrep, nparts, device_assignment))


# Allocate device data.
# Each operand gets one 1 x 2 buffer on each of the two devices (half of the
# 2 x 2 global shape along the first dimension).
host_x = 3 * np.ones((1,2), dtype=np.float32)
device_x_a = backend.buffer_from_pyval(host_x, device=devices[0])
device_x_b = backend.buffer_from_pyval(host_x, device=devices[1])
host_y = 4 * np.ones((1,2), dtype=np.float32)
device_y_a = backend.buffer_from_pyval(host_y, device=devices[0])
device_y_b = backend.buffer_from_pyval(host_y, device=devices[1])

# Arguments are grouped per operand, each holding one buffer per device. The
# per-device results are displayed as the cell output (3 + 4 = 7).
compiled_computation.execute_sharded_on_local_devices(
    [[device_x_a, device_x_b], [device_y_a, device_y_b]])

HloModule IrToHlo.21

%ShardedAddForwardWithCustomCall.13 (p0.14: f32[2,2], p1.15: f32[2,2]) -> f32[2,2] {
  %p0.14 = f32[2,2]{1,0} parameter(0)
  %custom-call.16 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %p0.14), custom_call_target="Sharding", sharding={devices=[2,1]0,1}
  %p1.15 = f32[2,2]{1,0} parameter(1)
  %custom-call.17 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %p1.15), custom_call_target="Sharding", sharding={devices=[2,1]0,1}
  ROOT %add.18 = f32[2,2]{1,0} add(f32[2,2]{1,0} %custom-call.16, f32[2,2]{1,0} %custom-call.17), sharding={devices=[2,1]0,1}
}

ENTRY %IrToHlo.21 (p0.1: f32[2,2], p1.7: f32[2,2]) -> (f32[2,2]) {
  %constant.2 = f32[] constant(0), metadata={op_type="prim__Constant" op_name="prim__Constant"}
  %reshape.3 = f32[1,1]{1,0} reshape(f32[] %constant.2), metadata={op_type="aten__expand" op_name="aten__expand"}
  %broadcast.4 = f32[1,1]{1,0} broadcast(f32[1,1]{1,0} %reshape.3), dimensions={0,1}, metadata={op_type="aten__expand" op_name="aten__expand"}
  %reshape.5

[[DeviceArray([[7., 7.],
               [7., 7.]], dtype=float32),
  DeviceArray([[7., 7.],
               [7., 7.]], dtype=float32)]]

#### Sample SPMD partitioner pass runtime logs

```
I0615 14:52:19.273297 1044120 2a886c8_compiler_base.cc:3450] XLA::TPU running hlo passes for 19 instructions, modules: IrToHlo.21
I0615 14:52:19.273737 1044120 2a886c8_compiler_base.cc:3517] HLO optimizing 9 instructions
I0615 14:52:19.273757 1044120 2a886c8_compiler_base.cc:3532] XLA::TPU HLO optimization
I0615 14:52:19.274559 1044120 spmd_partitioner.cc:3533]
I0615 14:52:19.274578 1044120 spmd_partitioner.cc:3533]
I0615 14:52:19.274580 1044120 spmd_partitioner.cc:3533] ***** SPMD memory usage before partition *****
I0615 14:52:19.274582 1044120 spmd_partitioner.cc:3533]
I0615 14:52:19.274584 1044120 spmd_partitioner.cc:3533]   ** Replicated instructions
I0615 14:52:19.274585 1044120 spmd_partitioner.cc:3533]
I0615 14:52:19.274587 1044120 spmd_partitioner.cc:3533]   ** All instructions
I0615 14:52:19.274589 1044120 spmd_partitioner.cc:3533]   256.00MiB : %p1.7 = f32[8192,8192]{1,0} parameter(1), sharding={devices=[2,2]0,1,2,3}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
I0615 14:52:19.274591 1044120 spmd_partitioner.cc:3533]   256.00MiB : %p0.1 = f32[8192,8192]{1,0} parameter(0), sharding={devices=[2,2]0,1,2,3}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
I0615 14:52:19.274593 1044120 spmd_partitioner.cc:3533]   256.00MiB : %convolution = f32[8192,8192]{1,0} convolution(f32[8192,8192]{1,0} %p1.7, f32[8192,8192]{1,0} %p0.1), dim_labels=bf_io->bf, sharding={devices=[2,2]0,1,2,3}, metadata={op_type="xla___op_ShardedMatMul" op_name="xla___op_ShardedMatMul"}
I0615 14:52:19.274640 1044120 spmd_partitioner.cc:3242] Partitioning computation IrToHlo.21 for 1 replicas and 4 partitions
I0615 14:52:19.275185 1044120 spmd_partitioner.cc:373] Resharding %param = f32[4096,4096]{1,0} parameter(1), sharding={devices=[2,2]0,1,2,3}, metadata={op_type="xla__device_data" op_name="xla__device_data"} from {devices=[2,2]0,1,2,3} to {devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
I0615 14:52:19.275408 1044120 spmd_partitioner.cc:373] Resharding %param = f32[4096,4096]{1,0} parameter(0), sharding={devices=[2,1]0,1}, metadata={op_type="xla__device_data" op_name="xla__device_data"} from {devices=[2,1]0,1} to {replicated}
I0615 14:52:19.275618 1044120 spmd_partitioner.cc:373] Resharding %convolution = f32[4096,4096]{1,0} convolution(f32[4096,8192]{1,0} %reshape, f32[8192,4096]{1,0} %reshape), dim_labels=bf_io->bf, sharding={devices=[2,2]0,1,2,3}, metadata={op_type="xla___op_ShardedMatMul" op_name="xla___op_ShardedMatMul"} from {devices=[2,2]0,1,2,3} to {replicated}
I0615 14:52:19.275768 1044120 spmd_partitioner.cc:3595]
I0615 14:52:19.275778 1044120 spmd_partitioner.cc:3595]
I0615 14:52:19.275780 1044120 spmd_partitioner.cc:3595] ***** SPMD memory usage after partition *****
I0615 14:52:19.275781 1044120 spmd_partitioner.cc:3595]   256.00MiB : %all-gather.2 = f32[4,4096,4096]{2,1,0} all-gather(f32[1,4096,4096]{2,1,0} %reshape.10), channel_id=3, replica_groups={{0,1,2,3}}, dimensions={0}, use_global_device_ids=true
I0615 14:52:19.275787 1044120 spmd_partitioner.cc:3595]   256.00MiB : %reshape.12 = f32[2,2,4096,4096]{3,2,1,0} reshape(f32[4,4096,4096]{2,1,0} %all-gather.2)
I0615 14:52:19.275788 1044120 spmd_partitioner.cc:3595]   256.00MiB : %transpose.2 = f32[2,4096,2,4096]{3,1,2,0} transpose(f32[2,2,4096,4096]{3,2,1,0} %reshape.12), dimensions={0,2,1,3}
I0615 14:52:19.275790 1044120 spmd_partitioner.cc:3595]   256.00MiB : %reshape.13 = f32[8192,8192]{1,0} reshape(f32[2,4096,2,4096]{3,1,2,0} %transpose.2), sharding={replicated}
I0615 14:52:19.275792 1044120 spmd_partitioner.cc:3595]   128.00MiB : %all-gather = f32[2,4096,4096]{2,1,0} all-gather(f32[1,4096,4096]{2,1,0} %reshape.4), channel_id=1, replica_groups={{0,1},{2,3}}, dimensions={0}, use_global_device_ids=true, metadata={op_type="xla___op_ShardedMatMul" op_name="xla___op_ShardedMatMul"}
I0615 14:52:19.275802 1044120 spmd_partitioner.cc:3596]
I0615 14:52:19.275806 1044120 spmd_partitioner.cc:3596]
I0615 14:52:19.275807 1044120 spmd_partitioner.cc:3596] ***** SPMD memory during transformation *****
I0615 14:52:19.275809 1044120 spmd_partitioner.cc:3596]
I0615 14:52:19.275811 1044120 spmd_partitioner.cc:3596]   256.00MiB : %tuple = (f32[8192,8192]{1,0}) tuple(f32[8192,8192]{1,0} %reshape), sharding={{replicated}}     * %reshape = f32[1,4096,4096]{2,1,0} reshape(f32[4096,4096]{1,0} %convolution)
I0615 14:52:19.275812 1044120 spmd_partitioner.cc:3596]      * %all-gather = f32[4,4096,4096]{2,1,0} all-gather(f32[1,4096,4096]{2,1,0} %reshape), channel_id=3, replica_groups={{0,1,2,3}}, dimensions={0}, use_global_device_ids=true
I0615 14:52:19.275814 1044120 spmd_partitioner.cc:3596]      * %reshape = f32[2,2,4096,4096]{3,2,1,0} reshape(f32[4,4096,4096]{2,1,0} %all-gather)
I0615 14:52:19.275816 1044120 spmd_partitioner.cc:3596]      * %transpose = f32[2,4096,2,4096]{3,1,2,0} transpose(f32[2,2,4096,4096]{3,2,1,0} %reshape), dimensions={0,2,1,3}
I0615 14:52:19.275817 1044120 spmd_partitioner.cc:3596]      * %reshape = f32[8192,8192]{1,0} reshape(f32[2,4096,2,4096]{3,1,2,0} %transpose), sharding={replicated}
I0615 14:52:19.275819 1044120 spmd_partitioner.cc:3596]
I0615 14:52:19.275820 1044120 spmd_partitioner.cc:3596]
I0615 14:52:19.275822 1044120 spmd_partitioner.cc:3596]   128.00MiB : %convolution = f32[4096,4096]{1,0} convolution(f32[4096,8192]{1,0} %reshape, f32[8192,4096]{1,0} %reshape), dim_labels=bf_io->bf, sharding={devices=[2,2]0,1,2,3}, metadata={op_type="xla___op_ShardedMatMul" op_name="xla___op_ShardedMatMul"}     * %constant = u32[4]{0} constant({0, 0, 1, 1}), metadata={op_type="xla___op_ShardedMatMul" op_name="xla___op_ShardedMatMul"}
I0615 14:52:19.275823 1044120 spmd_partitioner.cc:3596]      * %dynamic-slice = u32[1]{0} dynamic-slice(u32[4]{0} %constant, u32[] %partition-id), dynamic_slice_sizes={1}, metadata={op_type="xla___op_ShardedMatMul" op_name="xla___op_ShardedMatMul"}
I0615 14:52:19.275825 1044120 spmd_partitioner.cc:3596]      * %reshape = u32[] reshape(u32[1]{0} %dynamic-slice), metadata={op_type="xla___op_ShardedMatMul" op_name="xla___op_ShardedMatMul"}
I0615 14:52:19.275827 1044120 spmd_partitioner.cc:3596]      * %constant = s32[4]{0} constant({0, 0, 1, 1}), metadata={op_type="xla___op_ShardedMatMul" op_name="xla___op_ShardedMatMul"}
I0615 14:52:19.275828 1044120 spmd_partitioner.cc:3596]      * %dynamic-slice = s32[1]{0} dynamic-slice(s32[4]{0} %constant, u32[] %partition-id), dynamic_slice_sizes={1}, metadata={op_type="xla___op_ShardedMatMul" op_name="xla___op_ShardedMatMul"}
I0615 14:52:19.275830 1044120 spmd_partitioner.cc:3596]      * %reshape = s32[] reshape(s32[1]{0} %dynamic-slice), metadata={op_type="xla___op_ShardedMatMul" op_name="xla___op_ShardedMatMul"}
I0615 14:52:19.275832 1044120 spmd_partitioner.cc:3596]      * %constant = s32[] constant(0), metadata={op_type="xla___op_ShardedMatMul" op_name="xla___op_ShardedMatMul"}
I0615 14:52:19.275833 1044120 spmd_partitioner.cc:3596]      * %constant = u32[4]{0} constant({0, 1, 2, 3}), metadata={op_type="xla___op_ShardedMatMul" op_name="xla___op_ShardedMatMul"}
I0615 14:52:19.275835 1044120 spmd_partitioner.cc:3596]      * %dynamic-slice = u32[1]{0} dynamic-slice(u32[4]{0} %constant, u32[] %partition-id), dynamic_slice_sizes={1}, metadata={op_type="xla___op_ShardedMatMul" op_name="xla___op_ShardedMatMul"}
I0615 14:52:19.275837 1044120 spmd_partitioner.cc:3596]      * %reshape = u32[] reshape(u32[1]{0} %dynamic-slice), metadata={op_type="xla___op_ShardedMatMul" op_name="xla___op_ShardedMatMul"}
I0615 14:52:19.275839 1044120 spmd_partitioner.cc:3596]      * %reshape = f32[1,4096,4096]{2,1,0} reshape(f32[4096,4096]{1,0} %param), metadata={op_type="xla___op_ShardedMatMul" op_name="xla___op_ShardedMatMul"}
I0615 14:52:19.275840 1044120 spmd_partitioner.cc:3596]      * %all-gather = f32[2,4096,4096]{2,1,0} all-gather(f32[1,4096,4096]{2,1,0} %reshape), channel_id=1, replica_groups={{0,1},{2,3}}, dimensions={0}, use_global_device_ids=true, metadata={op_type="xla___op_ShardedMatMul" op_name="xla___op_ShardedMatMul"}
I0615 14:52:19.275842 1044120 spmd_partitioner.cc:3596]      * %transpose = f32[4096,2,4096]{2,0,1} transpose(f32[2,4096,4096]{2,1,0} %all-gather), dimensions={1,0,2}, metadata={op_type="xla___op_ShardedMatMul" op_name="xla___op_ShardedMatMul"}
I0615 14:52:19.275843 1044120 spmd_partitioner.cc:3596]      * %reshape = f32[4096,8192]{1,0} reshape(f32[4096,2,4096]{2,0,1} %transpose), sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}, metadata={op_type="xla___op_ShardedMatMul" op_name="xla___op_ShardedMatMul"}
I0615 14:52:19.275845 1044120 spmd_partitioner.cc:3596]      * %reshape = f32[1,4096,4096]{2,1,0} reshape(f32[4096,4096]{1,0} %param), metadata={op_type="xla___op_ShardedMatMul" op_name="xla___op_ShardedMatMul"}
I0615 14:52:19.275847 1044120 spmd_partitioner.cc:3596]      * %all-gather = f32[2,4096,4096]{2,1,0} all-gather(f32[1,4096,4096]{2,1,0} %reshape), channel_id=2, replica_groups={{0,2},{1,3}}, dimensions={0}, use_global_device_ids=true, metadata={op_type="xla___op_ShardedMatMul" op_name="xla___op_ShardedMatMul"}
I0615 14:52:19.275848 1044120 spmd_partitioner.cc:3596]      * %transpose = f32[2,4096,4096]{2,1,0} transpose(f32[2,4096,4096]{2,1,0} %all-gather), dimensions={0,1,2}, metadata={op_type="xla___op_ShardedMatMul" op_name="xla___op_ShardedMatMul"}
I0615 14:52:19.275850 1044120 spmd_partitioner.cc:3596]      * %reshape = f32[8192,4096]{1,0} reshape(f32[2,4096,4096]{2,1,0} %transpose), sharding={replicated}, metadata={op_type="xla___op_ShardedMatMul" op_name="xla___op_ShardedMatMul"}
I0615 14:52:19.275851 1044120 spmd_partitioner.cc:3596]      * %convolution = f32[4096,4096]{1,0} convolution(f32[4096,8192]{1,0} %reshape, f32[8192,4096]{1,0} %reshape), dim_labels=bf_io->bf, sharding={devices=[2,2]0,1,2,3}, metadata={op_type="xla___op_ShardedMatMul" op_name="xla___op_ShardedMatMul"}
I0615 14:52:19.275853 1044120 spmd_partitioner.cc:3596]
I0615 14:52:19.275854 1044120 spmd_partitioner.cc:3596]
I0615 14:52:19.275856 1044120 spmd_partitioner.cc:3596]   64.00MiB : %param = f32[4096,4096]{1,0} parameter(1), sharding={devices=[2,2]0,1,2,3}, metadata={op_type="xla__device_data" op_name="xla__device_data"}     * %param = f32[4096,4096]{1,0} parameter(1), sharding={devices=[2,2]0,1,2,3}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
I0615 14:52:19.275858 1044120 spmd_partitioner.cc:3596]
I0615 14:52:19.275860 1044120 spmd_partitioner.cc:3596]
I0615 14:52:19.275862 1044120 spmd_partitioner.cc:3596]   64.00MiB : %param = f32[4096,4096]{1,0} parameter(0), sharding={devices=[2,2]0,1,2,3}, metadata={op_type="xla__device_data" op_name="xla__device_data"}     * %param = f32[4096,4096]{1,0} parameter(0), sharding={devices=[2,2]0,1,2,3}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
I0615 14:52:19.275863 1044120 spmd_partitioner.cc:3596]
I0615 14:52:19.277370 1044120 tpu_layout_assignment.cc:2037] Ran 2 additional passes of layout assignment to assign all layouts.
I0615 14:52:19.278897 1044120 window_config_assignment.cc:62] Retrieving backend configs from FDO profiles.
I0615 14:52:19.278940 1044120 2a886c8_compiler_base.cc:3218] XLA::TPU HLO PostOptimizationPipeline
I0615 14:52:20.832438 1044120 2a886c8_compiler_base.cc:1494] final program bundle count: 22,496 note this count does not reflect cycles spent executing delays.
I0615 14:52:20.836553 1044120 2a886c8_compiler_base.cc:1494] final program bundle count: 60 note this count does not reflect cycles spent executing delays.
I0615 14:52:20.869977 1044120 2a886c8_compiler_base.cc:1673] Program too large for IMEM. Divided into 2 overlays (946.0K).
I0615 14:52:20.877122 1044120 2a886c8_compiler_base.cc:1799] XLA::TPU program HBM usage: 512.97M / 15.48G
I0615 14:52:20.877145 1044120 2a886c8_compiler_base.cc:1829] XLA::TPU program VMEM usage: 15.00M / 16.00M
I0615 14:52:20.877158 1044120 2a886c8_compiler_base.cc:1840] Total hbm usage >= 1.14G:
I0615 14:52:20.877161 1044120 2a886c8_compiler_base.cc:1840]     reserved        530.00M
I0615 14:52:20.877162 1044120 2a886c8_compiler_base.cc:1840]     program         512.97M
I0615 14:52:20.877163 1044120 2a886c8_compiler_base.cc:1840]     arguments       128.00M
I0615 14:52:20.877165 1044120 2a886c8_compiler_base.cc:1840]
I0615 14:52:20.877167 1044120 2a886c8_compiler_base.cc:1840] Output size 256.00M; shares 0B with arguments.
I0615 14:52:20.877168 1044120 2a886c8_compiler_base.cc:1840]
I0615 14:52:20.877183 1044120 2a886c8_compiler_base.cc:1844] Program sflag requirement 128B:
I0615 14:52:20.877185 1044120 2a886c8_compiler_base.cc:1844]     reserved           100B
I0615 14:52:20.877187 1044120 2a886c8_compiler_base.cc:1844]     scoped              28B
I0615 14:52:20.877188 1044120 2a886c8_compiler_base.cc:1844] Program hbm requirement 512.97M:
I0615 14:52:20.877190 1044120 2a886c8_compiler_base.cc:1844]     global            52.0K
I0615 14:52:20.877191 1044120 2a886c8_compiler_base.cc:1844]     HLO temp        512.00M (100.0% utilization: Unpadded (512.00M) Padded (512.00M), 0.0% fragmentation (0B))
I0615 14:52:20.877193 1044120 2a886c8_compiler_base.cc:1844]     overlays         946.0K
I0615 14:52:20.877195 1044120 2a886c8_compiler_base.cc:1844] Program vmem requirement 15.00M:
I0615 14:52:20.877196 1044120 2a886c8_compiler_base.cc:1844]     scoped           15.00M
I0615 14:52:20.877197 1044120 2a886c8_compiler_base.cc:1844] Program smem requirement 2.0K:
I0615 14:52:20.877199 1044120 2a886c8_compiler_base.cc:1844]     scoped             2.0K
I0615 14:52:20.877200 1044120 2a886c8_compiler_base.cc:1852] XLA::TPU program SMEM usage: 2.3K / 16.0K (2 parameters)
I0615 14:52:20.880040 1044120 isa_program.cc:370] Executable fingerprint:875f932b296c245088309b0b2e4ada4df6bc2f54bdb249b50a9622cfbf651470
I0615 14:52:29.849659 1044716 futex.cc:60] RAW: Futex::Swap(): using FUTEX_WAKE + FUTEX_WAIT
```

## Sharded JIT HLO graph

In [17]:
from jax.experimental import sharded_jit
from jax.experimental import PartitionSpec as P
from jax._src.util import prod
from jax.lib import xla_bridge as jxb

In [18]:
# Convert a 2 x 1 partition spec to its OpSharding proto and display the object.
# NOTE(review): this rebinds the module-level `p` and `py_opsharding`, which the
# `_sharded_*` functions above read when they run.
p = P(2,1)
py_opsharding = jxb._sharding_to_proto(p)
py_opsharding


<jaxlib.xla_client.OpSharding at 0x7f66c41eec20>

In [19]:
def f(x, y):
  """Return the product `x * y` (the function handed to `sharded_jit`)."""
  product = x * y
  return product

# Shard both inputs and the output of `f` over a 2 x 1 tile grid.
sharded_f = sharded_jit(f, in_parts=(P(2, 1), P(2, 1)), out_parts=(P(2, 1)))

# Two identical 2 x 2 float32 inputs holding 0..3.
shape = (2, 2)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
y = np.arange(prod(shape), dtype=np.float32).reshape(shape)
print(x)
print(y)

# Execute the sharded function and show the product.
z = sharded_f(x, y)
print(z)
# Dump the HLO JAX generates for the same call. Trace with (x, y), the same
# arguments as the executed call above, rather than passing `x` twice.
jax_sharded_hlo = jax.xla_computation(sharded_f)(x, y).as_hlo_text()
print(jax_sharded_hlo)



[[0. 1.]
 [2. 3.]]
[[0. 1.]
 [2. 3.]]
[[0. 1.]
 [4. 9.]]
HloModule xla_computation_f.16

sharded_jit_f.4 {
  constant.9 = pred[] constant(false)
  parameter.5 = f32[2,2]{1,0} parameter(0)
  custom-call.6 = f32[2,2]{1,0} custom-call(parameter.5), custom_call_target="Sharding", sharding={devices=[2,1]0,1}
  parameter.7 = f32[2,2]{1,0} parameter(1)
  custom-call.8 = f32[2,2]{1,0} custom-call(parameter.7), custom_call_target="Sharding", sharding={devices=[2,1]0,1}
  multiply.10 = f32[2,2]{1,0} multiply(custom-call.6, custom-call.8)
  custom-call.11 = f32[2,2]{1,0} custom-call(multiply.10), custom_call_target="Sharding", sharding={devices=[2,1]0,1}
  ROOT tuple.12 = (f32[2,2]{1,0}) tuple(custom-call.11)
}

ENTRY xla_computation_f.16 {
  constant.3 = pred[] constant(false)
  parameter.1 = f32[2,2]{1,0} parameter(0)
  parameter.2 = f32[2,2]{1,0} parameter(1)
  call.13 = (f32[2,2]{1,0}) call(parameter.1, parameter.2), to_apply=sharded_jit_f.4
  get-tuple-element.14 = f32[2,2]{1,0} get-tuple-el