<a href="https://colab.research.google.com/github/jysohn23/xla/blob/model-parallel-colab/Gather_Scatter_Broadcast_PyTorch_XLA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Installing PyTorch/XLA

Run the following cell (or copy it into your own notebook!) to install PyTorch, Torchvision, and PyTorch/XLA. It will take a couple of minutes to run.

In [1]:
# Colab form field: selects which PyTorch/XLA build the setup script installs.
VERSION = "nightly"  #@param ["1.5" , "20200325", "nightly"]
# Download the environment setup script from the pytorch/xla repository (master branch).
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# Run the downloaded script, passing the version chosen above.
!python pytorch-xla-env-setup.py --version $VERSION

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0100  3727  100  3727    0     0  14334      0 --:--:-- --:--:-- --:--:-- 14389
Updating TPU and VM. This may take around 2 minutes.
Updating TPU runtime to pytorch-nightly ...
Uninstalling torch-1.4.0:
Done updating TPU runtime: <Response [200]>
  Successfully uninstalled torch-1.4.0
Uninstalling torchvision-0.5.0:
  Successfully uninstalled torchvision-0.5.0
Copying gs://tpu-pytorch/wheels/torch-nightly-cp36-cp36m-linux_x86_64.whl...
/ [1 files][ 86.7 MiB/ 86.7 MiB]                                                
Operation completed over 1 objects/86.7 MiB.                                     
Copying gs://tpu-pytorch/wheels/torch_xla-nightly-cp36-cp36m-linux_x86_64.whl...
- [1 files][116.0 MiB/116.0 MiB]                                               

In [84]:
# all_gather with sub-groups via all_reduce
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

# Replica groups: the 8 ordinals split into consecutive pairs,
# i.e. [[0, 1], [2, 3], [4, 5], [6, 7]].
groups = [[first, first + 1] for first in range(0, 8, 2)]
# All groups have the same length, so the first one is representative.
group_size = len(groups[0])

def all_gather(t: torch.Tensor) -> torch.Tensor:
  """Gathers `t` from every member of the caller's group via a sum all-reduce.

  Each core places its tensor in its own row of an otherwise-zero buffer of
  shape `(group_size,) + t.shape`. Summing those buffers across the group
  leaves every member's tensor in its row on every core of the group.

  Args:
    t: The tensor contributed by the calling core.

  Returns:
    A tensor of shape `(group_size,) + t.shape` on the XLA device, with the
    same dtype as `t`.
  """
  # Keep the caller's dtype rather than forcing float32, so non-float32
  # inputs are not silently cast while being gathered.
  s = torch.zeros((group_size,) + t.shape, dtype=t.dtype)
  # Row index is the ordinal's position inside its group. This relies on the
  # groups being runs of consecutive ordinals of equal size, as in `groups`.
  s[xm.get_ordinal() % group_size] = t
  s = s.to(xm.xla_device())
  # The tensor is passed in a list and the return value is ignored: the
  # reduced values are read back from `s` itself.
  xm.all_reduce('sum', [s], groups=groups)
  return s

def _mp_fn(rank: int):
  """Per-process body: builds a tensor from the ordinal and gathers it.

  Prints the local tensor, runs `all_gather` on it, prints the gathered
  result, and finally waits on the 'init' rendezvous.

  Args:
    rank: Argument supplied by the spawner; the body does not read it.
  """
  torch.manual_seed(1)

  # The original bound this to a local that was never read; only the call
  # itself is kept.
  xm.xla_device()
  # Four copies of this core's ordinal, as a CPU tensor.
  t = torch.Tensor([xm.get_ordinal() for _ in range(4)])

  print(f'ordinal={xm.get_ordinal()}, t={t}')
  t_r = all_gather(t)
  print(f'ordinal={xm.get_ordinal()}, all-reduced t_r={t_r}\n')

  xm.rendezvous('init')

# Run `_mp_fn` in 8 processes using the 'fork' start method.
xmp.spawn(
    _mp_fn,
    nprocs=8,
    start_method='fork',
)

ordinal=0, t=tensor([0., 0., 0., 0.])
ordinal=4, t=tensor([4., 4., 4., 4.])
ordinal=1, t=tensor([1., 1., 1., 1.])
ordinal=2, t=tensor([2., 2., 2., 2.])
ordinal=6, t=tensor([6., 6., 6., 6.])
ordinal=5, t=tensor([5., 5., 5., 5.])
ordinal=3, t=tensor([3., 3., 3., 3.])
ordinal=7, t=tensor([7., 7., 7., 7.])
ordinal=2, all-reduced t_r=tensor([[2., 2., 2., 2.],
        [3., 3., 3., 3.]], device='xla:0')
ordinal=4, all-reduced t_r=tensor([[4., 4., 4., 4.],
        [5., 5., 5., 5.]], device='xla:0')
ordinal=7, all-reduced t_r=tensor([[6., 6., 6., 6.],
        [7., 7., 7., 7.]], device='xla:0')
ordinal=0, all-reduced t_r=tensor([[0., 0., 0., 0.],
        [1., 1., 1., 1.]], device='xla:1')
ordinal=6, all-reduced t_r=tensor([[6., 6., 6., 6.],
        [7., 7., 7., 7.]], device='xla:0')
ordinal=3, all-reduced t_r=tensor([[2., 2., 2., 2.],
        [3., 3., 3., 3.]], device='xla:0')
ordinal=5, all-reduced t_r=tensor([[4., 4., 4., 4.],
        [5., 5., 5., 5.]], device='xla:0')
ordinal=1, all-reduced t

In [86]:
# Broadcast
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

# Replica groups: two groups of four consecutive ordinals,
# i.e. [[0, 1, 2, 3], [4, 5, 6, 7]].
groups = [list(range(start, start + 4)) for start in (0, 4)]
# All groups have the same length, so the first one is representative.
group_size = len(groups[0])
# Base tensor length; `_broadcast_fn` scales it by (group index + 1).
size = 4

# Note: with xla::CollectivePermute it should also be possible
# to implement this with something like (per group):
#
#   source_target_pairs = [[0,0], [0,1], [0,2], [0,3]]
#

def _broadcast_fn(rank: int):
  """Broadcasts the first core's tensor to the rest of its group.

  The first ordinal of each group holds the values 0..nsize-1 and every other
  member holds zeros, so a sum all-reduce inside the group leaves the
  broadcaster's values on every member. The tensor is printed before and
  after the all-reduce, followed by 'ok' or a mismatch message.

  Args:
    rank: Argument supplied by the spawner; the body does not read it.
  """
  torch.manual_seed(1)

  # Group index and position within the group, from the global ordinal.
  group_num, group_pos = divmod(xm.get_ordinal(), group_size)
  # Tensor length grows with the group index: `size` for group 0, 2*size for group 1.
  nsize = (group_num + 1) * size
  # What the broadcaster sends, and what every member should end up with.
  expected = torch.linspace(start=0, end=nsize - 1, steps=nsize)

  # Initial CPU tensor for each ordinal:
  # 0: tensor([0., 1., 2., 3.])
  # 1: tensor([0., 0., 0., 0.])
  # 2: tensor([0., 0., 0., 0.])
  # 3: tensor([0., 0., 0., 0.])
  # 4: tensor([0., 1., 2., 3., 4., 5., 6., 7.])
  # 5: tensor([0., 0., 0., 0., 0., 0., 0., 0.])
  # 6: tensor([0., 0., 0., 0., 0., 0., 0., 0.])
  # 7: tensor([0., 0., 0., 0., 0., 0., 0., 0.])
  if group_pos == 0:
    t_cpu = expected  # broadcaster
  else:
    t_cpu = torch.zeros(nsize)  # receiver
  t = t_cpu.to(xm.xla_device())

  print(f'ordinal={xm.get_ordinal()}, t={t}')
  xm.all_reduce('sum', [t], groups=groups)
  print(f'ordinal={xm.get_ordinal()}, all-reduced t={t}')

  # Compare the reduced tensor with the broadcaster's payload on the host.
  if t.cpu().tolist() == expected.tolist():
    print('ok')
  else:
    print(f'Wrong result from core {xm.get_ordinal()}: {t}')

  xm.rendezvous('init')

# Run `_broadcast_fn` in 8 processes using the 'fork' start method.
xmp.spawn(
    _broadcast_fn,
    nprocs=8,
    start_method='fork',
)

ordinal=0, t=tensor([0., 1., 2., 3.], device='xla:1')
ordinal=5, t=tensor([0., 0., 0., 0., 0., 0., 0., 0.], device='xla:0')
ordinal=1, t=tensor([0., 0., 0., 0.], device='xla:0')
ordinal=3, t=tensor([0., 0., 0., 0.], device='xla:0')
ordinal=7, t=tensor([0., 0., 0., 0., 0., 0., 0., 0.], device='xla:0')
ordinal=4, t=tensor([0., 1., 2., 3., 4., 5., 6., 7.], device='xla:0')
ordinal=2, t=tensor([0., 0., 0., 0.], device='xla:0')
ordinal=6, t=tensor([0., 0., 0., 0., 0., 0., 0., 0.], device='xla:0')
ordinal=0, all-reduced t=tensor([0., 1., 2., 3.], device='xla:1')
ordinal=4, all-reduced t=tensor([0., 1., 2., 3., 4., 5., 6., 7.], device='xla:0')
ordinal=7, all-reduced t=tensor([0., 1., 2., 3., 4., 5., 6., 7.], device='xla:0')
ordinal=5, all-reduced t=tensor([0., 1., 2., 3., 4., 5., 6., 7.], device='xla:0')
ordinal=6, all-reduced t=tensor([0., 1., 2., 3., 4., 5., 6., 7.], device='xla:0')
ordinal=3, all-reduced t=tensor([0., 1., 2., 3.], device='xla:0')
ordinal=2, all-reduced t=tensor([0., 1., 2.,

In [87]:
# Scatter (broadcast + _split)

# Replica groups: two groups of four consecutive ordinals,
# i.e. [[0, 1, 2, 3], [4, 5, 6, 7]].
groups = [list(range(0, 4)), list(range(4, 8))]
# All groups have the same length, so the first one is representative.
group_size = len(groups[0])


def _scatter_fn(rank):
  """Scatters the first core's tensor across its group (broadcast, then index).

  The first ordinal of each group holds one value per group member and the
  others hold zeros. A sum all-reduce inside the group copies those values to
  every member, and each member then keeps only the element at its own
  position within the group.

  Args:
    rank: Argument supplied by the spawner; the body does not read it.
  """
  torch.manual_seed(1)

  # Group index and position within the group, from the global ordinal.
  group_num, group_pos = divmod(xm.get_ordinal(), group_size)
  # First value of this group's payload; equals the group's first ordinal.
  first = group_num * group_size
  if group_pos == 0:
    # broadcaster: first, first+1, ..., first+group_size-1
    t = torch.linspace(start=first, end=first + group_size - 1, steps=group_size)
  else:
    # receiver
    t = torch.zeros(group_size)
  t = t.to(xm.xla_device())

  print(f'ordinal={xm.get_ordinal()}, t={t}')
  xm.all_reduce('sum', [t], groups=groups) # bcast
  t = t[group_pos] # _split
  print(f'ordinal={xm.get_ordinal()}, bcasted t={t}\n')

  xm.rendezvous('init')

# Run `_scatter_fn` in 8 processes using the 'fork' start method.
xmp.spawn(
    _scatter_fn,
    nprocs=8,
    start_method='fork',
)

ordinal=0, t=tensor([0., 1., 2., 3.], device='xla:1')
ordinal=6, t=tensor([0., 0., 0., 0.], device='xla:0')
ordinal=7, t=tensor([0., 0., 0., 0.], device='xla:0')
ordinal=5, t=tensor([0., 0., 0., 0.], device='xla:0')
ordinal=1, t=tensor([0., 0., 0., 0.], device='xla:0')
ordinal=3, t=tensor([0., 0., 0., 0.], device='xla:0')
ordinal=4, t=tensor([4., 5., 6., 7.], device='xla:0')
ordinal=2, t=tensor([0., 0., 0., 0.], device='xla:0')
ordinal=0, bcasted t=0.0
ordinal=5, bcasted t=5.0
ordinal=1, bcasted t=1.0
ordinal=6, bcasted t=6.0
ordinal=2, bcasted t=2.0
ordinal=4, bcasted t=4.0






ordinal=3, bcasted t=3.0
ordinal=7, bcasted t=7.0




In [0]:
# [WIP - AllToAll API pending] all_gather with sub-groups via all_to_all
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

# Replica groups: the 8 ordinals split into consecutive pairs,
# i.e. [[0, 1], [2, 3], [4, 5], [6, 7]].
groups = [list(range(start, start + 2)) for start in (0, 2, 4, 6)]
# All groups have the same length, so the first one is representative.
group_size = len(groups[0])

def all_gather(t: torch.Tensor) -> torch.Tensor:
  """Exchanges `t` within the caller's group using `xm.all_to_all` (WIP).

  Moves `t` to the XLA device and runs an all-to-all over dimension 0 among
  the members of each group in `groups`.

  Args:
    t: The tensor contributed by the calling core. Its dimension 0 is split
      into `group_size` pieces, so that dimension presumably has to be
      divisible by `group_size` (TODO confirm once the API is available).

  Returns:
    The tensor returned by `xm.all_to_all`.
  """
  t = t.to(xm.xla_device())
  # NOTE(review): all-to-all hands each peer one slice of `t` rather than a
  # full copy, so the result is likely a mix of slices, not a stacked gather.
  # Confirm the intended semantics when the AllToAll API lands.
  s = xm.all_to_all(
      t,
      split_dimension=0,
      concat_dimension=0,
      # One slice per group member: with replica groups, XLA AllToAll expects
      # split_count to equal the group size (it was hard-coded to 4 while the
      # groups above have 2 members).
      split_count=group_size,
      groups=groups)
  return s

def _mp_fn(rank: int):
  """Per-process body for the all_to_all variant of `all_gather`.

  Builds a 4-element tensor filled with this core's ordinal, prints it, passes
  it through `all_gather`, prints the result, and waits on the 'init'
  rendezvous.

  Args:
    rank: Argument supplied by the spawner; the body does not read it.
  """
  torch.manual_seed(1)

  # The original bound this to a local that was never read; only the call
  # itself is kept.
  xm.xla_device()
  ordinal_values = [xm.get_ordinal()] * 4
  t = torch.Tensor(ordinal_values)

  print(f'ordinal={xm.get_ordinal()}, t={t}')
  t_r = all_gather(t)
  print(f'ordinal={xm.get_ordinal()}, all-reduced t_r={t_r}\n')

  xm.rendezvous('init')

# Run `_mp_fn` in 8 processes using the 'fork' start method.
xmp.spawn(
    _mp_fn,
    nprocs=8,
    start_method='fork',
)