In [5]:
import os
# Point XLA's HLO dump at a local directory. This runs before torch_xla is
# imported in the next cell -- presumably the flag has to be in the environment
# before the XLA runtime starts up (TODO confirm).
os.environ.update({"XLA_FLAGS": "--xla_dump_to=/home/bbahl/hlo"})

In [None]:
import torch
import torch_xla
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

# Turn on SPMD mode in the torch_xla runtime; called here before any tensor is
# moved to the 'xla' device further down in this cell.
xr.use_spmd()

class MarkShardingFunction(torch.autograd.Function):
    """Identity autograd op that attaches a sharding annotation to a tensor in
    the forward pass and to its gradient in the backward pass.

    The mesh is taken from the module-level ``mesh`` variable at the moment
    ``forward`` runs, so ``mesh`` must be bound before the function is applied.
    That mesh is stored on ``ctx`` and reused in ``backward``; rebinding the
    global ``mesh`` between the forward call and ``.backward()`` therefore
    cannot make the gradient use a different mesh than the activation.
    """

    @staticmethod
    def forward(ctx, input: torch.Tensor, partition_spec):
        """
        Forward pass: Mark the input tensor with sharding annotation.

        Args:
            ctx: autograd context; receives ``partition_spec`` and ``mesh``
                so that ``backward`` can apply the same annotation.
            input: tensor to annotate. It is returned unchanged.
            partition_spec: partition spec forwarded to ``xs.mark_sharding``.

        Returns:
            The ``input`` tensor itself.
        """
        ctx.partition_spec = partition_spec
        # Capture the mesh now instead of re-reading the global in backward.
        ctx.mesh = mesh
        xs.mark_sharding(input, ctx.mesh, partition_spec)
        return input


    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass: Mark the gradient with sharding annotation

        Args:
            ctx: autograd context populated by ``forward``.
            grad_output: gradient flowing into this op.

        Returns:
            ``(new_grad_output, None)`` -- the annotated gradient for ``input``
            and ``None`` for the non-tensor ``partition_spec`` argument.
        """
        print("DEBUG running backward")
        # Adding a zero tensor yields a new tensor, so the annotation below is
        # placed on that new tensor rather than on ``grad_output`` itself.
        zero = torch.zeros((1,), device=grad_output.device, dtype=grad_output.dtype)
        new_grad_output = zero + grad_output
        xs.mark_sharding(new_grad_output, ctx.mesh, ctx.partition_spec)
        return new_grad_output, None
        

# Use the custom function.
# Mesh and partition spec come first: MarkShardingFunction.forward reads the
# module-level `mesh` when it is applied.
mesh = xs.get_1d_mesh('a')
xs.set_global_mesh(mesh)
partition_spec = ('a', None)
mark_sharding = MarkShardingFunction.apply  # Shortcut to call the function

# 8x8 random input created on the host, then moved to the XLA device.
x = torch.randn(8, 8, requires_grad=True).to('xla')

# z is an intermediate (non-leaf) result, so its .grad has to be retained
# explicitly to be inspected after backward.
z = torch.matmul(x, x)
z.retain_grad()

y = mark_sharding(z, partition_spec)  # Forward pass
t = y.sum()
t.backward()  # Backward pass (prints the DEBUG line from MarkShardingFunction)

# Show the HLO text for z.grad, then sync.
print(torch_xla._XLAC._get_xla_tensors_hlo([z.grad]))
torch_xla.sync()

DEBUG running backward
HloModule IrToHlo.19, entry_computation_layout={()->(f32[8,8]{1,0})}

ENTRY %IrToHlo.19 () -> (f32[8,8]) {
  %constant.8 = f32[] constant(0)
  %reshape.9 = f32[1]{0} reshape(f32[] %constant.8)
  %broadcast.10 = f32[1]{0} broadcast(f32[1]{0} %reshape.9), dimensions={0}
  %broadcast.13 = f32[8,1]{1,0} broadcast(f32[1]{0} %broadcast.10), dimensions={1}
  %reshape.14 = f32[8]{0} reshape(f32[8,1]{1,0} %broadcast.13)
  %broadcast.15 = f32[8,8]{1,0} broadcast(f32[8]{0} %reshape.14), dimensions={0}
  %constant.2 = f32[] constant(1)
  %broadcast.3 = f32[] broadcast(f32[] %constant.2), dimensions={}
  %reshape.4 = f32[1,1]{1,0} reshape(f32[] %broadcast.3)
  %broadcast.5 = f32[1,1]{1,0} broadcast(f32[1,1]{1,0} %reshape.4), dimensions={0,1}
  %reshape.6 = f32[] reshape(f32[1,1]{1,0} %broadcast.5)
  %broadcast.7 = f32[8,8]{1,0} broadcast(f32[] %reshape.6), dimensions={}
  %constant.1 = f32[] constant(1)
  %broadcast.11 = f32[8,8]{1,0} broadcast(f32[] %constant.1), dimensions=



In [18]:
# List the contents of ../hlo, resolved against the notebook's working directory.
# NOTE(review): presumably the same directory as the absolute --xla_dump_to path
# set in the first cell -- confirm the working directory.
!ls ../hlo

module_0037.SyncTensorsGraph.28.after_codegen.txt
module_0037.SyncTensorsGraph.28.after_optimizations-buffer-assignment.txt
module_0037.SyncTensorsGraph.28.after_optimizations-memory-usage-report.txt
module_0037.SyncTensorsGraph.28.after_optimizations.txt
module_0037.SyncTensorsGraph.28.after_optimizations_after_buffer_assignment.txt
module_0037.SyncTensorsGraph.28.after_optimizations_before_buffer_assignment.txt
module_0037.SyncTensorsGraph.28.before_optimizations.txt
module_0037.SyncTensorsGraph.28.execution_options.txt
module_0037.SyncTensorsGraph.28.flagfile
module_0037.SyncTensorsGraph.28.hlo_module_config.txt
module_0037.SyncTensorsGraph.28.target_arguments.txt
module_0037.SyncTensorsGraph.28.tpu_comp_env.txt
module_0037.SyncTensorsGraph.28.transfer_stats.txt
module_0041.ReplicateShardedData.6.after_codegen.txt
module_0041.ReplicateShardedData.6.after_optimizations-buffer-assignment.txt
module_0041.ReplicateShardedData.6.after_optimizations-memory-usage-report.txt
module_0041.Rep

In [20]:
# Print the before-optimizations HLO text of the ReplicateShardedData.6 module
# (module_0041) from the ../hlo directory.
!cat ../hlo/module_0041.ReplicateShardedData.6.before_optimizations.txt

HloModule ReplicateShardedData.6, entry_computation_layout={(f32[8,8]{1,0:T(2,128)})->f32[8,8]{1,0:T(8,128)}}, num_partitions=4

ENTRY ReplicateShardedData.6 {
  p0.1 = f32[8,8]{1,0} parameter(0), sharding={devices=[4,1]0,1,2,3}
  constant.2 = s32[] constant(0), sharding={replicated}
  convert.3 = f32[] convert(constant.2), sharding={replicated}
  broadcast.4 = f32[8,8]{1,0} broadcast(convert.3), dimensions={}, sharding={replicated}
  ROOT add.5 = f32[8,8]{1,0} add(p0.1, broadcast.4), sharding={replicated}
}

