In [1]:
import os
os.environ["XLA_FLAGS"] = "--xla_dump_to=/home/bbahl/hlo"
from torch import nn
import torch
import torch_xla
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()
mesh = xs.get_1d_mesh('a')
xs.set_global_mesh(mesh)



In [2]:
class MarkShardingFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor, partition_spec):
        """
        Forward pass: Mark the input tensor with sharding annotation.
        """
        ctx.partition_spec = partition_spec
        mesh = xs.get_global_mesh()
        print("DEBUG running forward")
        xs.mark_sharding(input, mesh, partition_spec)
        return input


    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass: Mark the gradient with sharding annotation
        """
        print("DEBUG running backward")
        partition_spec = ctx.partition_spec
        mesh = xs.get_global_mesh()
        zero = torch.zeros((1,), device=grad_output.device, dtype=grad_output.dtype)
        new_grad_output = zero+grad_output
        xs.mark_sharding(new_grad_output, mesh, partition_spec)
        return new_grad_output, None, None
        

In [3]:
class SimpleLinearModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(128, 128)
        self.w2 = nn.Linear(128, 256)

    def forward(self, x):
        out = self.w1(x)
        out.retain_grad()
        MarkShardingFunction.apply(out, (None, 'a'))
        out = self.w2(out)
        return out

In [4]:
model = SimpleLinearModel()
model = model.to(torch.bfloat16).to('xla')
x = torch.randn((10, 128), dtype=torch.bfloat16).to('xla')
y = model(x)
labels = torch.ones((10, 256), dtype=torch.bfloat16).to('xla')
loss_func = nn.MSELoss()
loss = loss_func(y, labels)
loss.backward()

DEBUG running forward


In [None]:
# Use the custom function
x = torch.randn(8, 8, requires_grad=True)
x = x.to('xla')
partition_spec = ('a', None)
z = x@x
z.retain_grad()
y = MarkShardingFunction.apply (z, partition_spec)  # Forward pass
t = y.sum()
t.backward()
print(torch_xla._XLAC._get_xla_tensors_hlo([z.grad]))
torch_xla.sync()

In [None]:
!ls ../hlo

In [None]:
!cat ../hlo/module_0041.ReplicateShardedData.6.before_optimizations.txt